diff --git a/nibabel/__init__.py b/nibabel/__init__.py
index 779f6e8587..36310ca019 100644
--- a/nibabel/__init__.py
+++ b/nibabel/__init__.py
@@ -64,6 +64,7 @@ from .imageclasses import class_map, ext_map, all_image_classes
 from . import trackvis
 from . import mriutils
+from . import streamlines
 
 # be friendly on systems with ancient numpy -- no tests, but at least
 # importable
diff --git a/nibabel/externals/six.py b/nibabel/externals/six.py
index eae31454ae..b23166effd 100644
--- a/nibabel/externals/six.py
+++ b/nibabel/externals/six.py
@@ -143,6 +143,7 @@ class _MovedItems(types.ModuleType):
     MovedAttribute("StringIO", "StringIO", "io"),
     MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
     MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
+    MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"),
 
     MovedModule("builtins", "__builtin__"),
     MovedModule("configparser", "ConfigParser"),
diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py
new file mode 100644
index 0000000000..6cb4bdd87a
--- /dev/null
+++ b/nibabel/streamlines/__init__.py
@@ -0,0 +1,131 @@
+import os
+from ..externals.six import string_types
+
+from .header import TractogramHeader
+from .compact_list import CompactList
+from .tractogram import Tractogram, LazyTractogram
+
+from nibabel.streamlines.trk import TrkFile
+#from nibabel.streamlines.tck import TckFile
+#from nibabel.streamlines.vtk import VtkFile
+
+# List of all supported formats
+FORMATS = {".trk": TrkFile,
+           #".tck": TckFile,
+           #".vtk": VtkFile,
+           }
+
+
+def is_supported(fileobj):
+    ''' Checks if the file-like object is supported by NiBabel.
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object pointing
+        to a streamlines file (and ready to read from the beginning of the
+        header)
+
+    Returns
+    -------
+    is_supported : boolean
+    '''
+    return detect_format(fileobj) is not None
+
+
+def detect_format(fileobj):
+    ''' Returns the ``TractogramFile`` class guessed from the file-like object.
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object pointing
+        to a tractogram file (and ready to read from the beginning of the
+        header)
+
+    Returns
+    -------
+    tractogram_file : ``TractogramFile`` class
+        The ``TractogramFile`` subclass guessed from `fileobj`, or None if
+        the format could not be detected.
+    '''
+    for format in FORMATS.values():
+        try:
+            if format.is_correct_format(fileobj):
+                return format
+
+        except IOError:
+            pass
+
+    if isinstance(fileobj, string_types):
+        _, ext = os.path.splitext(fileobj)
+        return FORMATS.get(ext.lower())
+
+    return None
+
+
+def load(fileobj, lazy_load=False, ref=None):
+    ''' Loads streamlines from a file-like object in voxel space.
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object
+        pointing to a streamlines file (and ready to read from the beginning
+        of the streamlines file's header).
+
+    lazy_load : boolean (optional)
+        Load streamlines in a lazy manner, i.e. they will not be kept
+        in memory.
+
+    ref : filename | `Nifti1Image` object | 2D array (4,4) (optional)
+        Reference space where streamlines will live in `fileobj`.
+        (Currently unused by this function.)
+
+    Returns
+    -------
+    tractogram_file : ``TractogramFile``
+        Returns an instance of a `TractogramFile` class containing data and
+        metadata of the tractogram loaded from `fileobj`.
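+
+    Examples
+    --------
+    A minimal usage sketch; the filename below is hypothetical::
+
+        >>> import nibabel as nib
+        >>> tfile = nib.streamlines.load("tracts.trk")  # doctest: +SKIP
+        >>> tractogram = tfile.tractogram  # doctest: +SKIP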
+ ''' + tractogram_file = detect_format(fileobj) + + if tractogram_file is None: + raise TypeError("Unknown format for 'fileobj': {}".format(fileobj)) + + return tractogram_file.load(fileobj, lazy_load=lazy_load) + + +def save(tractogram_file, filename): + ''' Saves a tractogram to a file. + + Parameters + ---------- + tractogram_file : ``TractogramFile`` object + Tractogram to be saved on disk. + + filename : str + Name of the file where the tractogram will be saved. The format will + be guessed from `filename`. + ''' + tractogram_file.save(filename) + + +def save_tractogram(tractogram, filename, **kwargs): + ''' Saves a tractogram to a file. + + Parameters + ---------- + tractogram : ``Tractogram`` object + Tractogram to be saved. + + filename : str + Name of the file where the tractogram will be saved. The format will + be guessed from `filename`. + ''' + tractogram_file_class = detect_format(filename) + + if tractogram_file_class is None: + raise TypeError("Unknown tractogram file format: '{}'".format(filename)) + + tractogram_file = tractogram_file_class(tractogram, **kwargs) + tractogram_file.save(filename) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py new file mode 100644 index 0000000000..8fb761f312 --- /dev/null +++ b/nibabel/streamlines/compact_list.py @@ -0,0 +1,221 @@ +import numpy as np + + +class CompactList(object): + """ Class for compacting list of ndarrays with matching shape except for + the first dimension. + """ + + BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. + + def __init__(self, iterable=None): + """ + Parameters + ---------- + iterable : iterable (optional) + If specified, create a ``CompactList`` object initialized from + iterable's items. Otherwise, create an empty ``CompactList``. + + Notes + ----- + If `iterable` is a ``CompactList`` object, a view is returned and no + memory is allocated. For an actual copy use the `.copy()` method. + """ + # Create new empty `CompactList` object. + self._data = None + self._offsets = [] + self._lengths = [] + + if isinstance(iterable, CompactList): + # Create a view. + self._data = iterable._data + self._offsets = iterable._offsets + self._lengths = iterable._lengths + + elif iterable is not None: + # Initialize the `CompactList` object from iterable's item. + offset = 0 + for i, e in enumerate(iterable): + e = np.asarray(e) + if i == 0: + new_shape = (CompactList.BUFFER_SIZE,) + e.shape[1:] + self._data = np.empty(new_shape, dtype=e.dtype) + + end = offset + len(e) + if end >= len(self._data): + # Resize needed, adding `len(e)` new items plus some buffer. + nb_points = len(self._data) + nb_points += len(e) + CompactList.BUFFER_SIZE + self._data.resize((nb_points,) + self.shape) + + self._offsets.append(offset) + self._lengths.append(len(e)) + self._data[offset:offset+len(e)] = e + offset += len(e) + + # Clear unused memory. + if self._data is not None: + self._data.resize((offset,) + self.shape) + + @property + def shape(self): + """ Returns the matching shape of the elements in this compact list. """ + if self._data is None: + return None + + return self._data.shape[1:] + + def append(self, element): + """ Appends `element` to this compact list. + + Parameters + ---------- + element : ndarray + Element to append. The shape must match already inserted elements + shape except for the first dimension. + + Notes + ----- + If you need to add multiple elements you should consider + `CompactList.extend`. 
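+
+        Examples
+        --------
+        A short doctest-style sketch of the expected behaviour::
+
+            >>> import numpy as np
+            >>> clist = CompactList()
+            >>> clist.append(np.zeros((5, 3)))
+            >>> len(clist)
+            1
+            >>> clist.shape
+            (3,)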
+ """ + if self._data is None: + self._data = np.asarray(element).copy() + self._offsets.append(0) + self._lengths.append(len(element)) + return + + if element.shape[1:] != self.shape: + raise ValueError("All dimensions, except the first one," + " must match exactly") + + self._offsets.append(len(self._data)) + self._lengths.append(len(element)) + self._data = np.append(self._data, element, axis=0) + + def extend(self, elements): + """ Appends all `elements` to this compact list. + + Parameters + ---------- + elements : list of ndarrays, ``CompactList`` object + Elements to append. The shape must match already inserted elements + shape except for the first dimension. + + """ + if self._data is None: + elem = np.asarray(elements[0]) + self._data = np.zeros((0, elem.shape[1]), dtype=elem.dtype) + + next_offset = self._data.shape[0] + + if isinstance(elements, CompactList): + self._data.resize((self._data.shape[0]+sum(elements._lengths), + self._data.shape[1])) + + for offset, length in zip(elements._offsets, elements._lengths): + self._offsets.append(next_offset) + self._lengths.append(length) + self._data[next_offset:next_offset+length] = elements._data[offset:offset+length] + next_offset += length + + else: + self._data = np.concatenate([self._data] + list(elements), axis=0) + lengths = list(map(len, elements)) + self._lengths.extend(lengths) + self._offsets.extend(np.cumsum([next_offset] + lengths).tolist()[:-1]) + + def copy(self): + """ Creates a copy of this ``CompactList`` object. """ + # We do not simply deepcopy this object since we might have a chance + # to use less memory. For example, if the compact list being copied + # is the result of a slicing operation on a compact list. + clist = CompactList() + total_lengths = np.sum(self._lengths) + clist._data = np.empty((total_lengths,) + self._data.shape[1:], + dtype=self._data.dtype) + + idx = 0 + for offset, length in zip(self._offsets, self._lengths): + clist._offsets.append(idx) + clist._lengths.append(length) + clist._data[idx:idx+length] = self._data[offset:offset+length] + idx += length + + return clist + + def __getitem__(self, idx): + """ Gets element(s) through indexing. + + Parameters + ---------- + idx : int, slice or list + Index of the element(s) to get. + + Returns + ------- + ndarray object(s) + When `idx` is an int, returns a single ndarray. + When `idx` is either a slice or a list, returns a list of ndarrays. + """ + if isinstance(idx, int) or isinstance(idx, np.integer): + start = self._offsets[idx] + return self._data[start:start+self._lengths[idx]] + + elif isinstance(idx, slice): + clist = CompactList() + clist._data = self._data + clist._offsets = self._offsets[idx] + clist._lengths = self._lengths[idx] + return clist + + elif isinstance(idx, list): + clist = CompactList() + clist._data = self._data + clist._offsets = [self._offsets[i] for i in idx] + clist._lengths = [self._lengths[i] for i in idx] + return clist + + elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: + clist = CompactList() + clist._data = self._data + clist._offsets = [self._offsets[i] + for i, take_it in enumerate(idx) if take_it] + clist._lengths = [self._lengths[i] + for i, take_it in enumerate(idx) if take_it] + return clist + + raise TypeError("Index must be either an int, a slice, a list of int" + " or a ndarray of bool! 
Not " + str(type(idx)))
+
+    def __iter__(self):
+        if len(self._lengths) != len(self._offsets):
+            raise ValueError("CompactList object corrupted:"
+                             " len(self._lengths) != len(self._offsets)")
+
+        for offset, length in zip(self._offsets, self._lengths):
+            yield self._data[offset: offset+length]
+
+    def __len__(self):
+        return len(self._offsets)
+
+    def __repr__(self):
+        return repr(list(self))
+
+
+def save_compact_list(filename, clist):
+    """ Saves a `CompactList` object to a .npz file. """
+    np.savez(filename,
+             data=clist._data,
+             offsets=clist._offsets,
+             lengths=clist._lengths)
+
+
+def load_compact_list(filename):
+    """ Loads a `CompactList` object from a .npz file. """
+    content = np.load(filename)
+    clist = CompactList()
+    clist._data = content["data"]
+    clist._offsets = content["offsets"].tolist()
+    clist._lengths = content["lengths"].tolist()
+    return clist
diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py
new file mode 100644
index 0000000000..3fac6952bd
--- /dev/null
+++ b/nibabel/streamlines/header.py
@@ -0,0 +1,141 @@
+import copy
+import numpy as np
+from nibabel.orientations import aff2axcodes
+from nibabel.externals import OrderedDict
+
+
+class Field:
+    """ Header fields common to multiple streamlines file formats.
+
+    In IPython, use `nibabel.streamlines.Field??` to list them.
+    """
+    NB_STREAMLINES = "nb_streamlines"
+    STEP_SIZE = "step_size"
+    METHOD = "method"
+    NB_SCALARS_PER_POINT = "nb_scalars_per_point"
+    NB_PROPERTIES_PER_STREAMLINE = "nb_properties_per_streamline"
+    NB_POINTS = "nb_points"
+    VOXEL_SIZES = "voxel_sizes"
+    DIMENSIONS = "dimensions"
+    MAGIC_NUMBER = "magic_number"
+    ORIGIN = "origin"
+    VOXEL_TO_RASMM = "voxel_to_rasmm"
+    VOXEL_ORDER = "voxel_order"
+    ENDIAN = "endian"
+
+
+class TractogramHeader(object):
+    def __init__(self, hdr=None):
+        self._nb_streamlines = None
+        self._nb_scalars_per_point = None
+        self._nb_properties_per_streamline = None
+        self._to_world_space = np.eye(4)
+        self.extra = OrderedDict()
+
+        if type(hdr) is dict:
+            if Field.NB_STREAMLINES in hdr:
+                self.nb_streamlines = hdr[Field.NB_STREAMLINES]
+
+            if Field.NB_SCALARS_PER_POINT in hdr:
+                self.nb_scalars_per_point = hdr[Field.NB_SCALARS_PER_POINT]
+
+            if Field.NB_PROPERTIES_PER_STREAMLINE in hdr:
+                self.nb_properties_per_streamline = hdr[Field.NB_PROPERTIES_PER_STREAMLINE]
+
+            if Field.VOXEL_TO_RASMM in hdr:
+                self.to_world_space = hdr[Field.VOXEL_TO_RASMM]
+
+        elif type(hdr) is TractogramHeader:
+            self._nb_streamlines = hdr._nb_streamlines
+            self._nb_scalars_per_point = hdr._nb_scalars_per_point
+            self._nb_properties_per_streamline = hdr._nb_properties_per_streamline
+            self._to_world_space = hdr._to_world_space
+            self.extra = copy.deepcopy(hdr.extra)
+
+    @property
+    def to_world_space(self):
+        return self._to_world_space
+
+    @to_world_space.setter
+    def to_world_space(self, value):
+        self._to_world_space = np.asarray(value, dtype=np.float32)
+
+    @property
+    def voxel_sizes(self):
+        """ Get voxel sizes from to_world_space. """
+        return np.sqrt(np.sum(self.to_world_space[:3, :3]**2, axis=0))
+
+    @voxel_sizes.setter
+    def voxel_sizes(self, value):
+        scaling = np.r_[np.array(value), [1]]
+        old_scaling = np.r_[np.array(self.voxel_sizes), [1]]
+        # Remove old scaling and apply new one
+        self.to_world_space = np.dot(np.diag(scaling/old_scaling), self.to_world_space)
+
+    @property
+    def voxel_order(self):
+        """ Get voxel order from to_world_space.
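+
+        Examples
+        --------
+        A small sketch of the default behaviour (identity affine)::
+
+            >>> TractogramHeader().voxel_order
+            'RAS'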
""" + return "".join(aff2axcodes(self.to_world_space)) + + @property + def nb_streamlines(self): + return self._nb_streamlines + + @nb_streamlines.setter + def nb_streamlines(self, value): + self._nb_streamlines = int(value) + + @property + def nb_scalars_per_point(self): + return self._nb_scalars_per_point + + @nb_scalars_per_point.setter + def nb_scalars_per_point(self, value): + self._nb_scalars_per_point = int(value) + + @property + def nb_properties_per_streamline(self): + return self._nb_properties_per_streamline + + @nb_properties_per_streamline.setter + def nb_properties_per_streamline(self, value): + self._nb_properties_per_streamline = int(value) + + @property + def extra(self): + return self._extra + + @extra.setter + def extra(self, value): + self._extra = OrderedDict(value) + + def copy(self): + header = TractogramHeader() + header._nb_streamlines = self.nb_streamlines + header.nb_scalars_per_point = self.nb_scalars_per_point + header.nb_properties_per_streamline = self.nb_properties_per_streamline + header.to_world_space = self.to_world_space.copy() + header.extra = copy.deepcopy(self.extra) + return header + + def __eq__(self, other): + return (np.allclose(self.to_world_space, other.to_world_space) and + self.nb_streamlines == other.nb_streamlines and + self.nb_scalars_per_point == other.nb_scalars_per_point and + self.nb_properties_per_streamline == other.nb_properties_per_streamline and + repr(self.extra) == repr(other.extra)) # Not the robust way, but will do! + + def __repr__(self): + txt = "Header{\n" + txt += "nb_streamlines: " + repr(self.nb_streamlines) + '\n' + txt += "nb_scalars_per_point: " + repr(self.nb_scalars_per_point) + '\n' + txt += "nb_properties_per_streamline: " + repr(self.nb_properties_per_streamline) + '\n' + txt += "to_world_space: " + repr(self.to_world_space) + '\n' + txt += "voxel_sizes: " + repr(self.voxel_sizes) + '\n' + + txt += "Extra fields: {\n" + for key in sorted(self.extra.keys()): + txt += " " + repr(key) + ": " + repr(self.extra[key]) + "\n" + + txt += " }\n" + return txt + "}" diff --git a/nibabel/streamlines/tests/__init__.py b/nibabel/streamlines/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nibabel/streamlines/tests/data/complex.trk b/nibabel/streamlines/tests/data/complex.trk new file mode 100644 index 0000000000..0a874ea6e7 Binary files /dev/null and b/nibabel/streamlines/tests/data/complex.trk differ diff --git a/nibabel/streamlines/tests/data/empty.trk b/nibabel/streamlines/tests/data/empty.trk new file mode 100644 index 0000000000..fbe0871807 Binary files /dev/null and b/nibabel/streamlines/tests/data/empty.trk differ diff --git a/nibabel/streamlines/tests/data/gen_standard.py b/nibabel/streamlines/tests/data/gen_standard.py new file mode 100644 index 0000000000..63b4173602 --- /dev/null +++ b/nibabel/streamlines/tests/data/gen_standard.py @@ -0,0 +1,79 @@ +import numpy as np +import nibabel as nib + +from nibabel.streamlines import FORMATS +from nibabel.streamlines.header import Field + + +def mark_the_spot(mask): + """ Marks every nonzero voxel using streamlines to form a 3D 'X' inside. + + Generates streamlines forming a 3D 'X' inside every nonzero voxel. + + Parameters + ---------- + mask : ndarray + Mask containing the spots to be marked. + + Returns + ------- + list of ndarrays + All streamlines needed to mark every nonzero voxel in the `mask`. 
+ """ + def _gen_straight_streamline(start, end, steps=3): + coords = [] + for s, e in zip(start, end): + coords.append(np.linspace(s, e, steps)) + + return np.array(coords).T + + # Generate a 3D 'X' template fitting inside the voxel centered at (0,0,0). + X = [] + X.append(_gen_straight_streamline((-0.5, -0.5, -0.5), (0.5, 0.5, 0.5))) + X.append(_gen_straight_streamline((-0.5, 0.5, -0.5), (0.5, -0.5, 0.5))) + X.append(_gen_straight_streamline((-0.5, 0.5, 0.5), (0.5, -0.5, -0.5))) + X.append(_gen_straight_streamline((-0.5, -0.5, 0.5), (0.5, 0.5, -0.5))) + + # Get the coordinates of voxels 'on' in the mask. + coords = np.array(zip(*np.where(mask))) + + streamlines = [] + for c in coords: + for line in X: + streamlines.append((line + c) * voxel_size) + + return streamlines + + +rng = np.random.RandomState(42) + +width = 4 # Coronal +height = 5 # Sagittal +depth = 7 # Axial + +voxel_size = np.array((1., 3., 2.)) + +# Generate a random mask with voxel order RAS+. +mask = rng.rand(width, height, depth) > 0.8 +mask = (255*mask).astype(np.uint8) + +# Build tractogram +streamlines = mark_the_spot(mask) +tractogram = nib.streamlines.Tractogram(streamlines) + +# Build header +affine = np.eye(4) +affine[range(3), range(3)] = voxel_size +header = {Field.DIMENSIONS: (width, height, depth), + Field.VOXEL_SIZES: voxel_size, + Field.VOXEL_TO_RASMM: affine, + Field.VOXEL_ORDER: 'RAS'} + +# Save the standard mask. +nii = nib.Nifti1Image(mask, affine=affine) +nib.save(nii, "standard.nii.gz") + +# Save the standard tractogram in every available file format. +for ext, cls in FORMATS.items(): + tfile = cls(tractogram, header) + nib.streamlines.save(tfile, "standard" + ext) diff --git a/nibabel/streamlines/tests/data/simple.trk b/nibabel/streamlines/tests/data/simple.trk new file mode 100644 index 0000000000..df601e29a7 Binary files /dev/null and b/nibabel/streamlines/tests/data/simple.trk differ diff --git a/nibabel/streamlines/tests/data/standard.LPS.trk b/nibabel/streamlines/tests/data/standard.LPS.trk new file mode 100644 index 0000000000..ebda71bdb8 Binary files /dev/null and b/nibabel/streamlines/tests/data/standard.LPS.trk differ diff --git a/nibabel/streamlines/tests/data/standard.nii.gz b/nibabel/streamlines/tests/data/standard.nii.gz new file mode 100644 index 0000000000..98bb31a778 Binary files /dev/null and b/nibabel/streamlines/tests/data/standard.nii.gz differ diff --git a/nibabel/streamlines/tests/data/standard.trk b/nibabel/streamlines/tests/data/standard.trk new file mode 100644 index 0000000000..01ea01744a Binary files /dev/null and b/nibabel/streamlines/tests/data/standard.trk differ diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py new file mode 100644 index 0000000000..188102ce85 --- /dev/null +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -0,0 +1,286 @@ +import os +import unittest +import tempfile +import numpy as np + +from nose.tools import assert_equal, assert_raises, assert_true +from nibabel.testing import assert_arrays_equal +from numpy.testing import assert_array_equal +from nibabel.externals.six.moves import zip, zip_longest + +from ..compact_list import (CompactList, + load_compact_list, + save_compact_list) + + +class TestCompactList(unittest.TestCase): + + def setUp(self): + rng = np.random.RandomState(42) + self.data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + self.lengths = list(map(len, self.data)) + self.clist = CompactList(self.data) + + def test_creating_empty_compactlist(self): + clist = 
CompactList() + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + def test_creating_compactlist_from_list(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = list(map(len, data)) + + clist = CompactList(data) + assert_equal(len(clist), len(data)) + assert_equal(len(clist._offsets), len(data)) + assert_equal(len(clist._lengths), len(data)) + assert_equal(clist._data.shape[0], sum(lengths)) + assert_equal(clist._data.shape[1], 3) + assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist._lengths, lengths) + assert_equal(clist.shape, data[0].shape[1:]) + + # Empty list + clist = CompactList([]) + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + # Force CompactList constructor to use buffering. + old_buffer_size = CompactList.BUFFER_SIZE + CompactList.BUFFER_SIZE = 1 + clist = CompactList(data) + assert_equal(len(clist), len(data)) + assert_equal(len(clist._offsets), len(data)) + assert_equal(len(clist._lengths), len(data)) + assert_equal(clist._data.shape[0], sum(lengths)) + assert_equal(clist._data.shape[1], 3) + assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist._lengths, lengths) + assert_equal(clist.shape, data[0].shape[1:]) + CompactList.BUFFER_SIZE = old_buffer_size + + def test_creating_compactlist_from_generator(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = list(map(len, data)) + + gen = (e for e in data) + clist = CompactList(gen) + assert_equal(len(clist), len(data)) + assert_equal(len(clist._offsets), len(data)) + assert_equal(len(clist._lengths), len(data)) + assert_equal(clist._data.shape[0], sum(lengths)) + assert_equal(clist._data.shape[1], 3) + assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist._lengths, lengths) + assert_equal(clist.shape, data[0].shape[1:]) + + # Already consumed generator + clist = CompactList(gen) + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + def test_creating_compactlist_from_compact_list(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = list(map(len, data)) + + clist = CompactList(data) + clist2 = CompactList(clist) + assert_equal(len(clist2), len(data)) + assert_equal(len(clist2._offsets), len(data)) + assert_equal(len(clist2._lengths), len(data)) + assert_equal(clist2._data.shape[0], sum(lengths)) + assert_equal(clist2._data.shape[1], 3) + assert_equal(clist2._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist2._lengths, lengths) + assert_equal(clist2.shape, data[0].shape[1:]) + + def test_compactlist_iter(self): + for e, d in zip(self.clist, self.data): + assert_array_equal(e, d) + + # Try iterate through a corrupted CompactList object. 
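+        # ("Corrupted" here means the offsets and lengths lists have fallen
+        # out of sync; __iter__ checks for this and raises a ValueError.)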
+ clist = self.clist.copy() + clist._lengths = clist._lengths[::2] + assert_raises(ValueError, list, clist) + + def test_compactlist_copy(self): + clist = self.clist.copy() + assert_array_equal(clist._data, self.clist._data) + assert_true(clist._data is not self.clist._data) + assert_array_equal(clist._offsets, self.clist._offsets) + assert_true(clist._offsets is not self.clist._offsets) + assert_array_equal(clist._lengths, self.clist._lengths) + assert_true(clist._lengths is not self.clist._lengths) + + assert_equal(clist.shape, self.clist.shape) + + # When taking a copy of a `CompactList` generated by slicing. + # Only needed data should be kept. + clist = self.clist[::2].copy() + + assert_true(clist._data.shape[0] < self.clist._data.shape[0]) + assert_true(len(clist) < len(self.clist)) + assert_true(clist._data is not self.clist._data) + + def test_compactlist_append(self): + # Maybe not necessary if `self.setUp` is always called before a + # test method, anyways create a copy just in case. + clist = self.clist.copy() + + rng = np.random.RandomState(1234) + element = rng.rand(rng.randint(10, 50), *self.clist.shape) + clist.append(element) + assert_equal(len(clist), len(self.clist)+1) + assert_equal(clist._offsets[-1], len(self.clist._data)) + assert_equal(clist._lengths[-1], len(element)) + assert_array_equal(clist._data[-len(element):], element) + + # Append with different shape. + element = rng.rand(rng.randint(10, 50), 42) + assert_raises(ValueError, clist.append, element) + + # Append to an empty CompactList. + clist = CompactList() + rng = np.random.RandomState(1234) + shape = (2, 3, 4) + element = rng.rand(rng.randint(10, 50), *shape) + clist.append(element) + + assert_equal(len(clist), 1) + assert_equal(clist._offsets[-1], 0) + assert_equal(clist._lengths[-1], len(element)) + assert_array_equal(clist._data, element) + assert_equal(clist.shape, shape) + + def test_compactlist_extend(self): + # Maybe not necessary if `self.setUp` is always called before a + # test method, anyways create a copy just in case. + clist = self.clist.copy() + + rng = np.random.RandomState(1234) + shape = self.clist.shape + new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(10)] + lengths = list(map(len, new_data)) + clist.extend(new_data) + assert_equal(len(clist), len(self.clist)+len(new_data)) + assert_array_equal(clist._offsets[-len(new_data):], + len(self.clist._data) + np.cumsum([0] + lengths[:-1])) + + assert_equal(clist._lengths[-len(new_data):], lengths) + assert_array_equal(clist._data[-sum(lengths):], + np.concatenate(new_data, axis=0)) + + # Extend with another `CompactList` object. + clist = self.clist.copy() + new_clist = CompactList(new_data) + clist.extend(new_clist) + assert_equal(len(clist), len(self.clist)+len(new_clist)) + assert_array_equal(clist._offsets[-len(new_clist):], + len(self.clist._data) + np.cumsum([0] + lengths[:-1])) + + assert_equal(clist._lengths[-len(new_clist):], lengths) + assert_array_equal(clist._data[-sum(lengths):], new_clist._data) + + # Extend with another `CompactList` object that is a view (e.g. been sliced). + # Need to make sure we extend only the data we need. 
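+        # (A sliced view shares a larger `_data` buffer; only the view's own
+        # elements should be appended, not the whole shared buffer.)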
+ clist = self.clist.copy() + new_clist = CompactList(new_data)[::2] + clist.extend(new_clist) + assert_equal(len(clist), len(self.clist)+len(new_clist)) + assert_equal(len(clist._data), len(self.clist._data)+sum(new_clist._lengths)) + assert_array_equal(clist._offsets[-len(new_clist):], + len(self.clist._data) + np.cumsum([0] + new_clist._lengths[:-1])) + + assert_equal(clist._lengths[-len(new_clist):], lengths[::2]) + assert_array_equal(clist._data[-sum(new_clist._lengths):], new_clist.copy()._data) + assert_arrays_equal(clist[-len(new_clist):], new_clist) + + # Test extending an empty CompactList + clist = CompactList() + new_clist = CompactList(new_data) + clist.extend(new_clist) + assert_equal(len(clist), len(new_clist)) + assert_array_equal(clist._offsets, new_clist._offsets) + assert_array_equal(clist._lengths, new_clist._lengths) + assert_array_equal(clist._data, new_clist._data) + + + def test_compactlist_getitem(self): + # Get one item + for i, e in enumerate(self.clist): + assert_array_equal(self.clist[i], e) + + # Get multiple items (this will create a view). + clist_view = self.clist[list(range(len(self.clist)))] + assert_true(clist_view is not self.clist) + assert_true(clist_view._data is self.clist._data) + assert_true(clist_view._offsets is not self.clist._offsets) + assert_true(clist_view._lengths is not self.clist._lengths) + assert_array_equal(clist_view._offsets, self.clist._offsets) + assert_array_equal(clist_view._lengths, self.clist._lengths) + for e1, e2 in zip_longest(clist_view, self.clist): + assert_array_equal(e1, e2) + + # Get slice (this will create a view). + clist_view = self.clist[::2] + assert_true(clist_view is not self.clist) + assert_true(clist_view._data is self.clist._data) + assert_array_equal(clist_view._offsets, self.clist._offsets[::2]) + assert_array_equal(clist_view._lengths, self.clist._lengths[::2]) + for i, e in enumerate(clist_view): + assert_array_equal(e, self.clist[i*2]) + + # Use advance indexing with ndarray of data type bool. + idx = np.array([False, True, True, False, True]) + clist_view = self.clist[idx] + assert_true(clist_view is not self.clist) + assert_true(clist_view._data is self.clist._data) + assert_array_equal(clist_view._offsets, + np.asarray(self.clist._offsets)[idx]) + assert_array_equal(clist_view._lengths, + np.asarray(self.clist._lengths)[idx]) + assert_array_equal(clist_view[0], self.clist[1]) + assert_array_equal(clist_view[1], self.clist[2]) + assert_array_equal(clist_view[2], self.clist[4]) + + # Test invalid indexing + assert_raises(TypeError, self.clist.__getitem__, 'abc') + + def test_compactlist_repr(self): + # Test that calling repr on a CompactList object is not falling. 
+ repr(self.clist) + + +def test_save_and_load_compact_list(): + + with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + clist = CompactList() + save_compact_list(f, clist) + f.seek(0, os.SEEK_SET) + loaded_clist = load_compact_list(f) + assert_array_equal(loaded_clist._data, clist._data) + assert_array_equal(loaded_clist._offsets, clist._offsets) + assert_array_equal(loaded_clist._lengths, clist._lengths) + + with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + clist = CompactList(data) + save_compact_list(f, clist) + f.seek(0, os.SEEK_SET) + loaded_clist = load_compact_list(f) + assert_array_equal(loaded_clist._data, clist._data) + assert_array_equal(loaded_clist._offsets, clist._offsets) + assert_array_equal(loaded_clist._lengths, clist._lengths) diff --git a/nibabel/streamlines/tests/test_header.py b/nibabel/streamlines/tests/test_header.py new file mode 100644 index 0000000000..398195f615 --- /dev/null +++ b/nibabel/streamlines/tests/test_header.py @@ -0,0 +1,37 @@ +import numpy as np + +from nose.tools import assert_equal, assert_true +from numpy.testing import assert_array_equal + +from nibabel.streamlines.header import TractogramHeader + + +def test_streamlines_header(): + header = TractogramHeader() + assert_true(header.nb_streamlines is None) + assert_true(header.nb_scalars_per_point is None) + assert_true(header.nb_properties_per_streamline is None) + assert_array_equal(header.voxel_sizes, (1, 1, 1)) + assert_array_equal(header.to_world_space, np.eye(4)) + assert_equal(header.extra, {}) + + # Modify simple attributes + header.nb_streamlines = 1 + header.nb_scalars_per_point = 2 + header.nb_properties_per_streamline = 3 + assert_equal(header.nb_streamlines, 1) + assert_equal(header.nb_scalars_per_point, 2) + assert_equal(header.nb_properties_per_streamline, 3) + + # Modifying voxel_sizes should be reflected in to_world_space + header.voxel_sizes = (2, 3, 4) + assert_array_equal(header.voxel_sizes, (2, 3, 4)) + assert_array_equal(np.diag(header.to_world_space), (2, 3, 4, 1)) + + # Modifying scaling of to_world_space should be reflected in voxel_sizes + header.to_world_space = np.diag([4, 3, 2, 1]) + assert_array_equal(header.voxel_sizes, (4, 3, 2)) + assert_array_equal(header.to_world_space, np.diag([4, 3, 2, 1])) + + # Test that we can run __repr__ without error. + repr(header) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py new file mode 100644 index 0000000000..c21c688989 --- /dev/null +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -0,0 +1,249 @@ +import os +import unittest +import tempfile +import numpy as np + +from os.path import join as pjoin + +import nibabel as nib +from nibabel.externals.six import BytesIO + +from nibabel.testing import clear_and_catch_warnings +from nibabel.testing import assert_arrays_equal, check_iteration +from nose.tools import assert_equal, assert_raises, assert_true, assert_false + +from .test_tractogram import assert_tractogram_equal +from ..tractogram import Tractogram, LazyTractogram +from ..tractogram_file import TractogramFile +from ..tractogram import UsageWarning +from .. 
import trk
+
+DATA_PATH = pjoin(os.path.dirname(__file__), 'data')
+
+
+def test_is_supported():
+    # Empty file/string
+    f = BytesIO()
+    assert_false(nib.streamlines.is_supported(f))
+    assert_false(nib.streamlines.is_supported(""))
+
+    # Valid file without extension
+    for tfile_cls in nib.streamlines.FORMATS.values():
+        f = BytesIO()
+        f.write(tfile_cls.get_magic_number())
+        f.seek(0, os.SEEK_SET)
+        assert_true(nib.streamlines.is_supported(f))
+
+    # Wrong extension but right magic number
+    for tfile_cls in nib.streamlines.FORMATS.values():
+        with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f:
+            f.write(tfile_cls.get_magic_number())
+            f.seek(0, os.SEEK_SET)
+            assert_true(nib.streamlines.is_supported(f))
+
+    # Good extension but wrong magic number
+    for ext, tfile_cls in nib.streamlines.FORMATS.items():
+        with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f:
+            f.write(b"pass")
+            f.seek(0, os.SEEK_SET)
+            assert_false(nib.streamlines.is_supported(f))
+
+    # Wrong extension, string only
+    f = "my_tractogram.asd"
+    assert_false(nib.streamlines.is_supported(f))
+
+    # Good extension, string only
+    for ext, tfile_cls in nib.streamlines.FORMATS.items():
+        f = "my_tractogram" + ext
+        assert_true(nib.streamlines.is_supported(f))
+
+
+def test_detect_format():
+    # Empty file/string
+    f = BytesIO()
+    assert_true(nib.streamlines.detect_format(f) is None)
+    assert_true(nib.streamlines.detect_format("") is None)
+
+    # Valid file without extension
+    for tfile_cls in nib.streamlines.FORMATS.values():
+        f = BytesIO()
+        f.write(tfile_cls.get_magic_number())
+        f.seek(0, os.SEEK_SET)
+        assert_true(nib.streamlines.detect_format(f) is tfile_cls)
+
+    # Wrong extension but right magic number
+    for tfile_cls in nib.streamlines.FORMATS.values():
+        with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f:
+            f.write(tfile_cls.get_magic_number())
+            f.seek(0, os.SEEK_SET)
+            assert_true(nib.streamlines.detect_format(f) is tfile_cls)
+
+    # Good extension but wrong magic number
+    for ext, tfile_cls in nib.streamlines.FORMATS.items():
+        with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f:
+            f.write(b"pass")
+            f.seek(0, os.SEEK_SET)
+            assert_true(nib.streamlines.detect_format(f) is None)
+
+    # Wrong extension, string only
+    f = "my_tractogram.asd"
+    assert_true(nib.streamlines.detect_format(f) is None)
+
+    # Good extension, string only
+    for ext, tfile_cls in nib.streamlines.FORMATS.items():
+        f = "my_tractogram" + ext
+        assert_equal(nib.streamlines.detect_format(f), tfile_cls)
+
+    # Extension should not be case-sensitive.
+    for ext, tfile_cls in nib.streamlines.FORMATS.items():
+        f = "my_tractogram" + ext.upper()
+        assert_true(nib.streamlines.detect_format(f) is tfile_cls)
+
+
+class TestLoadSave(unittest.TestCase):
+    def setUp(self):
+        self.empty_filenames = [pjoin(DATA_PATH, "empty" + ext)
+                                for ext in nib.streamlines.FORMATS.keys()]
+        self.simple_filenames = [pjoin(DATA_PATH, "simple" + ext)
+                                 for ext in nib.streamlines.FORMATS.keys()]
+        self.complex_filenames = [pjoin(DATA_PATH, "complex" + ext)
+                                  for ext in nib.streamlines.FORMATS.keys()]
+
+        self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)),
+                            np.arange(2*3, dtype="f4").reshape((2, 3)),
+                            np.arange(5*3, dtype="f4").reshape((5, 3))]
+
+        self.fa = [np.array([[0.2]], dtype="f4"),
+                   np.array([[0.3],
+                             [0.4]], dtype="f4"),
+                   np.array([[0.5],
+                             [0.6],
+                             [0.6],
+                             [0.7],
+                             [0.8]], dtype="f4")]
+
+        self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"),
+                       np.array([(0, 1, 0)]*2, dtype="f4"),
+                       np.array([(0, 0, 1)]*5, dtype="f4")]
+
+        self.mean_curvature = [np.array([1.11], dtype="f4"),
+                               np.array([2.11], dtype="f4"),
+                               np.array([3.11], dtype="f4")]
+
+        self.mean_torsion = [np.array([1.22], dtype="f4"),
+                             np.array([2.22], dtype="f4"),
+                             np.array([3.22], dtype="f4")]
+
+        self.mean_colors = [np.array([1, 0, 0], dtype="f4"),
+                            np.array([0, 1, 0], dtype="f4"),
+                            np.array([0, 0, 1], dtype="f4")]
+
+        self.data_per_point = {'colors': self.colors,
+                               'fa': self.fa}
+        self.data_per_streamline = {'mean_curvature': self.mean_curvature,
+                                    'mean_torsion': self.mean_torsion,
+                                    'mean_colors': self.mean_colors}
+
+        self.empty_tractogram = Tractogram()
+        self.simple_tractogram = Tractogram(self.streamlines)
+        self.complex_tractogram = Tractogram(self.streamlines,
+                                             self.data_per_streamline,
+                                             self.data_per_point)
+
+    def test_load_empty_file(self):
+        for lazy_load in [False, True]:
+            for empty_filename in self.empty_filenames:
+                tfile = nib.streamlines.load(empty_filename,
+                                             lazy_load=lazy_load)
+                assert_true(isinstance(tfile, TractogramFile))
+
+                if lazy_load:
+                    assert_true(type(tfile.tractogram) is LazyTractogram)
+                else:
+                    assert_true(type(tfile.tractogram) is Tractogram)
+
+                assert_tractogram_equal(tfile.tractogram,
+                                        self.empty_tractogram)
+
+    def test_load_simple_file(self):
+        for lazy_load in [False, True]:
+            for simple_filename in self.simple_filenames:
+                tfile = nib.streamlines.load(simple_filename,
+                                             lazy_load=lazy_load)
+                assert_true(isinstance(tfile, TractogramFile))
+
+                if lazy_load:
+                    assert_true(type(tfile.tractogram) is LazyTractogram)
+                else:
+                    assert_true(type(tfile.tractogram) is Tractogram)
+
+                assert_tractogram_equal(tfile.tractogram,
+                                        self.simple_tractogram)
+
+    def test_load_complex_file(self):
+        for lazy_load in [False, True]:
+            for complex_filename in self.complex_filenames:
+                tfile = nib.streamlines.load(complex_filename,
+                                             lazy_load=lazy_load)
+                assert_true(isinstance(tfile, TractogramFile))
+
+                if lazy_load:
+                    assert_true(type(tfile.tractogram) is LazyTractogram)
+                else:
+                    assert_true(type(tfile.tractogram) is Tractogram)
+
+                tractogram = Tractogram(self.streamlines)
+
+                if tfile.support_data_per_point():
+                    tractogram.data_per_point = self.data_per_point
+
+                if tfile.support_data_per_streamline():
+                    tractogram.data_per_streamline = self.data_per_streamline
+
+                assert_tractogram_equal(tfile.tractogram,
+                                        tractogram)
+
+    def test_save_empty_file(self):
+        tractogram = Tractogram()
+        for ext, cls in nib.streamlines.FORMATS.items():
+            with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f:
+                nib.streamlines.save_tractogram(tractogram, f.name)
+                tfile
= nib.streamlines.load(f, lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) + + def test_save_simple_file(self): + tractogram = Tractogram(self.streamlines) + for ext, cls in nib.streamlines.FORMATS.items(): + with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: + nib.streamlines.save_tractogram(tractogram, f.name) + tfile = nib.streamlines.load(f, lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) + + def test_save_complex_file(self): + complex_tractogram = Tractogram(self.streamlines, + self.data_per_streamline, + self.data_per_point) + + for ext, cls in nib.streamlines.FORMATS.items(): + with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: + with clear_and_catch_warnings(record=True, modules=[trk]) as w: + nib.streamlines.save_tractogram(complex_tractogram, f.name) + + # If streamlines format does not support saving data per + # point or data per streamline, a warning message should + # be issued. + if not (cls.support_data_per_point() + and cls.support_data_per_streamline()): + assert_equal(len(w), 1) + assert_true(issubclass(w[0].category, UsageWarning)) + + tractogram = Tractogram(self.streamlines) + + if cls.support_data_per_point(): + tractogram.data_per_point = self.data_per_point + + if cls.support_data_per_streamline(): + tractogram.data_per_streamline = self.data_per_streamline + + tfile = nib.streamlines.load(f, lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py new file mode 100644 index 0000000000..953765bc9b --- /dev/null +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -0,0 +1,478 @@ +import unittest +import numpy as np +import warnings + +from nibabel.testing import assert_arrays_equal, check_iteration +from nibabel.testing import suppress_warnings, clear_and_catch_warnings +from nose.tools import assert_equal, assert_raises, assert_true +from numpy.testing import assert_array_equal, assert_array_almost_equal +from nibabel.externals.six.moves import zip + +from .. import tractogram as module_tractogram +from ..tractogram import UsageWarning +from ..tractogram import TractogramItem, Tractogram, LazyTractogram + + +def assert_tractogram_equal(t1, t2): + assert_true(check_iteration(t1)) + assert_equal(len(t1), len(t2)) + assert_arrays_equal(t1.streamlines, t2.streamlines) + + assert_equal(len(t1.data_per_streamline), len(t2.data_per_streamline)) + for key in t1.data_per_streamline.keys(): + assert_arrays_equal(t1.data_per_streamline[key], + t2.data_per_streamline[key]) + + assert_equal(len(t1.data_per_point), len(t2.data_per_point)) + for key in t1.data_per_point.keys(): + assert_arrays_equal(t1.data_per_point[key], + t2.data_per_point[key]) + + +class TestTractogramItem(unittest.TestCase): + + def test_creating_tractogram_item(self): + rng = np.random.RandomState(42) + streamline = rng.rand(rng.randint(10, 50), 3) + colors = rng.rand(len(streamline), 3) + mean_curvature = 1.11 + mean_color = np.array([0, 1, 0], dtype="f4") + + data_for_streamline = {"mean_curvature": mean_curvature, + "mean_color": mean_color} + + data_for_points = {"colors": colors} + + # Create a tractogram item with a streamline, data. 
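+        # (A TractogramItem bundles one streamline with its
+        # data_for_streamline and data_for_points dictionaries.)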
+ t = TractogramItem(streamline, data_for_streamline, data_for_points) + assert_equal(len(t), len(streamline)) + assert_array_equal(t.streamline, streamline) + assert_array_equal(list(t), streamline) + assert_array_equal(t.data_for_streamline['mean_curvature'], + mean_curvature) + assert_array_equal(t.data_for_streamline['mean_color'], + mean_color) + assert_array_equal(t.data_for_points['colors'], + colors) + + +class TestTractogram(unittest.TestCase): + + def setUp(self): + self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + self.mean_curvature = np.array([1.11, 2.11, 3.11], dtype="f4") + self.mean_color = np.array([[0, 1, 0], + [0, 0, 1], + [1, 0, 0]], dtype="f4") + + self.nb_streamlines = len(self.streamlines) + + def test_tractogram_creation(self): + # Create an empty tractogram. + tractogram = Tractogram() + assert_equal(len(tractogram), 0) + assert_arrays_equal(tractogram.streamlines, []) + assert_equal(tractogram.data_per_streamline, {}) + assert_equal(tractogram.data_per_point, {}) + assert_true(check_iteration(tractogram)) + + # Create a tractogram with only streamlines + tractogram = Tractogram(streamlines=self.streamlines) + assert_equal(len(tractogram), len(self.streamlines)) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_equal(tractogram.data_per_streamline, {}) + assert_equal(tractogram.data_per_point, {}) + assert_true(check_iteration(tractogram)) + + # Create a tractogram with streamlines and other data. + tractogram = Tractogram( + self.streamlines, + data_per_streamline={'mean_curvature': self.mean_curvature, + 'mean_color': self.mean_color}, + data_per_point={'colors': self.colors}) + + assert_equal(len(tractogram), len(self.streamlines)) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.data_per_streamline['mean_curvature'], + self.mean_curvature) + assert_arrays_equal(tractogram.data_per_streamline['mean_color'], + self.mean_color) + assert_arrays_equal(tractogram.data_per_point['colors'], + self.colors) + + assert_true(check_iteration(tractogram)) + + # Inconsistent number of scalars between streamlines + wrong_data = [[(1, 0, 0)]*1, + [(0, 1, 0), (0, 1)], + [(0, 0, 1)]*5] + + data_per_point = {'wrong_data': wrong_data} + assert_raises(ValueError, Tractogram, self.streamlines, + data_per_point=data_per_point) + + # Inconsistent number of scalars between streamlines + wrong_data = [[(1, 0, 0)]*1, + [(0, 1)]*2, + [(0, 0, 1)]*5] + + data_per_point = {'wrong_data': wrong_data} + assert_raises(ValueError, Tractogram, self.streamlines, + data_per_point=data_per_point) + + def test_tractogram_getitem(self): + # Tractogram with only streamlines + tractogram = Tractogram(streamlines=self.streamlines) + + selected_tractogram = tractogram[::2] + assert_equal(len(selected_tractogram), (len(self.streamlines)+1)//2) + + assert_arrays_equal(selected_tractogram.streamlines, + self.streamlines[::2]) + assert_equal(tractogram.data_per_streamline, {}) + assert_equal(tractogram.data_per_point, {}) + + # Create a tractogram with streamlines and other data. 
+ tractogram = Tractogram( + self.streamlines, + data_per_streamline={'mean_curvature': self.mean_curvature, + 'mean_color': self.mean_color}, + data_per_point={'colors': self.colors}) + + # Retrieve tractogram by their index + for i, t in enumerate(tractogram): + assert_array_equal(t.streamline, tractogram[i].streamline) + assert_array_equal(t.data_for_points['colors'], + tractogram[i].data_for_points['colors']) + + assert_array_equal(t.data_for_streamline['mean_curvature'], + tractogram[i].data_for_streamline['mean_curvature']) + + assert_array_equal(t.data_for_streamline['mean_color'], + tractogram[i].data_for_streamline['mean_color']) + + # Use slicing + r_tractogram = tractogram[::-1] + assert_arrays_equal(r_tractogram.streamlines, self.streamlines[::-1]) + + assert_arrays_equal(r_tractogram.data_per_streamline['mean_curvature'], + self.mean_curvature[::-1]) + assert_arrays_equal(r_tractogram.data_per_streamline['mean_color'], + self.mean_color[::-1]) + assert_arrays_equal(r_tractogram.data_per_point['colors'], + self.colors[::-1]) + + def test_tractogram_add_new_data(self): + # Tractogram with only streamlines + tractogram = Tractogram(streamlines=self.streamlines) + + tractogram.data_per_streamline['mean_curvature'] = self.mean_curvature + tractogram.data_per_streamline['mean_color'] = self.mean_color + tractogram.data_per_point['colors'] = self.colors + + # Retrieve tractogram by their index + for i, t in enumerate(tractogram): + assert_array_equal(t.streamline, tractogram[i].streamline) + assert_array_equal(t.data_for_points['colors'], + tractogram[i].data_for_points['colors']) + + assert_array_equal(t.data_for_streamline['mean_curvature'], + tractogram[i].data_for_streamline['mean_curvature']) + + assert_array_equal(t.data_for_streamline['mean_color'], + tractogram[i].data_for_streamline['mean_color']) + + # Use slicing + r_tractogram = tractogram[::-1] + assert_arrays_equal(r_tractogram.streamlines, self.streamlines[::-1]) + + assert_arrays_equal(r_tractogram.data_per_streamline['mean_curvature'], + self.mean_curvature[::-1]) + assert_arrays_equal(r_tractogram.data_per_streamline['mean_color'], + self.mean_color[::-1]) + assert_arrays_equal(r_tractogram.data_per_point['colors'], + self.colors[::-1]) + + def test_tractogram_copy(self): + # Create a tractogram with streamlines and other data. + tractogram1 = Tractogram( + self.streamlines, + data_per_streamline={'mean_curvature': self.mean_curvature, + 'mean_color': self.mean_color}, + data_per_point={'colors': self.colors}) + + # Create a copy of the tractogram. + tractogram2 = tractogram1.copy() + + # Check we copied the data and not simply created new references. + assert_true(tractogram1 is not tractogram2) + assert_true(tractogram1.streamlines is not tractogram2.streamlines) + assert_true(tractogram1.data_per_streamline + is not tractogram2.data_per_streamline) + assert_true(tractogram1.data_per_streamline['mean_curvature'] + is not tractogram2.data_per_streamline['mean_curvature']) + assert_true(tractogram1.data_per_streamline['mean_color'] + is not tractogram2.data_per_streamline['mean_color']) + assert_true(tractogram1.data_per_point + is not tractogram2.data_per_point) + assert_true(tractogram1.data_per_point['colors'] + is not tractogram2.data_per_point['colors']) + + # Check the data are the equivalent. 
+ assert_true(check_iteration(tractogram2)) + assert_equal(len(tractogram1), len(tractogram2)) + assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) + assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) + assert_arrays_equal(tractogram1.data_per_streamline['mean_curvature'], + tractogram2.data_per_streamline['mean_curvature']) + assert_arrays_equal(tractogram1.data_per_streamline['mean_color'], + tractogram2.data_per_streamline['mean_color']) + assert_arrays_equal(tractogram1.data_per_point['colors'], + tractogram2.data_per_point['colors']) + + +class TestLazyTractogram(unittest.TestCase): + + def setUp(self): + self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + self.mean_curvature = np.array([1.11, 2.11, 3.11], dtype="f4") + self.mean_color = np.array([[0, 1, 0], + [0, 0, 1], + [1, 0, 0]], dtype="f4") + + self.nb_streamlines = len(self.streamlines) + + self.colors_func = lambda: (x for x in self.colors) + self.mean_curvature_func = lambda: (x for x in self.mean_curvature) + self.mean_color_func = lambda: (x for x in self.mean_color) + + def test_lazy_tractogram_creation(self): + # To create tractogram from arrays use `Tractogram`. + assert_raises(TypeError, LazyTractogram, self.streamlines) + + # Streamlines and other data as generators + streamlines = (x for x in self.streamlines) + data_per_point = {"colors": (x for x in self.colors)} + data_per_streamline = {'mean_curv': (x for x in self.mean_curvature), + 'mean_color': (x for x in self.mean_color)} + + # Creating LazyTractogram with generators is not allowed as + # generators get exhausted and are not reusable unlike coroutines. + assert_raises(TypeError, LazyTractogram, streamlines) + assert_raises(TypeError, LazyTractogram, + data_per_streamline=data_per_streamline) + assert_raises(TypeError, LazyTractogram, self.streamlines, + data_per_point=data_per_point) + + # Empty `LazyTractogram` + tractogram = LazyTractogram() + assert_true(check_iteration(tractogram)) + assert_equal(len(tractogram), 0) + assert_arrays_equal(tractogram.streamlines, []) + assert_equal(tractogram.data_per_point, {}) + assert_equal(tractogram.data_per_streamline, {}) + + # Create tractogram with streamlines and other data + streamlines = lambda: (x for x in self.streamlines) + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} + + tractogram = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + + assert_true(check_iteration(tractogram)) + assert_equal(len(tractogram), self.nb_streamlines) + + # Coroutines get re-called and creates new iterators. + for i in range(2): + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.data_per_streamline['mean_curv'], + self.mean_curvature) + assert_arrays_equal(tractogram.data_per_streamline['mean_color'], + self.mean_color) + assert_arrays_equal(tractogram.data_per_point['colors'], + self.colors) + + def test_lazy_tractogram_create_from(self): + # Create `LazyTractogram` from a coroutine yielding nothing (i.e empty). 
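+        # ("Coroutine" in these tests means a zero-argument callable that
+        # returns a fresh iterator each time it is called, unlike a bare
+        # generator, which can be consumed only once.)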
+ _empty_data_gen = lambda: iter([]) + + tractogram = LazyTractogram.create_from(_empty_data_gen) + assert_true(check_iteration(tractogram)) + assert_equal(len(tractogram), 0) + assert_arrays_equal(tractogram.streamlines, []) + assert_equal(tractogram.data_per_point, {}) + assert_equal(tractogram.data_per_streamline, {}) + + # Create `LazyTractogram` from a coroutine yielding TractogramItem + def _data_gen(): + for d in zip(self.streamlines, self.colors, + self.mean_curvature, self.mean_color): + data_for_points = {'colors': d[1]} + data_for_streamline = {'mean_curv': d[2], + 'mean_color': d[3]} + yield TractogramItem(d[0], data_for_streamline, data_for_points) + + tractogram = LazyTractogram.create_from(_data_gen) + assert_true(check_iteration(tractogram)) + assert_equal(len(tractogram), self.nb_streamlines) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.data_per_streamline['mean_curv'], + self.mean_curvature) + assert_arrays_equal(tractogram.data_per_streamline['mean_color'], + self.mean_color) + assert_arrays_equal(tractogram.data_per_point['colors'], + self.colors) + + # Creating a LazyTractogram from not a corouting should raise an error. + assert_raises(TypeError, LazyTractogram.create_from, _data_gen()) + + def test_lazy_tractogram_getitem(self): + streamlines = lambda: (x for x in self.streamlines) + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} + + # By default, `LazyTractogram` object does not support indexing. + tractogram = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + assert_raises(AttributeError, tractogram.__getitem__, 0) + + def test_lazy_tractogram_len(self): + streamlines = lambda: (x for x in self.streamlines) + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} + + modules = [module_tractogram] # Modules for which to catch warnings. + with clear_and_catch_warnings(record=True, modules=modules) as w: + warnings.simplefilter("always") # Always trigger warnings. + + # Calling `len` will create new generators each time. + tractogram = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + assert_true(tractogram._nb_streamlines is None) + + # This should produce a warning message. + assert_equal(len(tractogram), self.nb_streamlines) + assert_equal(tractogram._nb_streamlines, self.nb_streamlines) + assert_equal(len(w), 1) + + tractogram = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + + # New instances should still produce a warning message. + assert_equal(len(tractogram), self.nb_streamlines) + assert_equal(len(w), 2) + assert_true(issubclass(w[-1].category, UsageWarning)) + + # Calling again 'len' again should *not* produce a warning. + assert_equal(len(tractogram), self.nb_streamlines) + assert_equal(len(w), 2) + + with clear_and_catch_warnings(record=True, modules=modules) as w: + # Once we iterated through the tractogram, we know the length. + + tractogram = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + + assert_true(tractogram._nb_streamlines is None) + check_iteration(tractogram) # Force iteration through tractogram. 
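+            # (Iterating the tractogram sets `_nb_streamlines`, so the
+            # `len` call below no longer needs to warn.)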
+ assert_equal(tractogram._nb_streamlines, len(self.streamlines)) + # This should *not* produce a warning. + assert_equal(len(tractogram), len(self.streamlines)) + assert_equal(len(w), 0) + + def test_lazy_tractogram_apply_affine(self): + streamlines = lambda: (x for x in self.streamlines) + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} + + affine = np.eye(4) + scaling = np.array((1, 2, 3), dtype=float) + affine[range(3), range(3)] = scaling + + tractogram = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + + tractogram.apply_affine(affine) + assert_true(check_iteration(tractogram)) + assert_equal(len(tractogram), len(self.streamlines)) + for s1, s2 in zip(tractogram.streamlines, self.streamlines): + assert_array_almost_equal(s1, s2*scaling) + + def test_lazy_tractogram_copy(self): + # Create tractogram with streamlines and other data + streamlines = lambda: (x for x in self.streamlines) + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} + + tractogram1 = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + assert_true(check_iteration(tractogram1)) # Implicitly set _nb_streamlines. + + # Create a copy of the tractogram. + tractogram2 = tractogram1.copy() + + # Check we copied the data and not simply created new references. + assert_true(tractogram1 is not tractogram2) + + # When copying LazyTractogram, coroutines generating streamlines should + # be the same. + assert_true(tractogram1._streamlines is tractogram2._streamlines) + + # Copying LazyTractogram, creates new internal LazyDict objects, + # but coroutines contained in it should be the same. + assert_true(tractogram1._data_per_streamline + is not tractogram2._data_per_streamline) + assert_true(tractogram1.data_per_streamline.store['mean_curv'] + is tractogram2.data_per_streamline.store['mean_curv']) + assert_true(tractogram1.data_per_streamline.store['mean_color'] + is tractogram2.data_per_streamline.store['mean_color']) + assert_true(tractogram1._data_per_point + is not tractogram2._data_per_point) + assert_true(tractogram1.data_per_point.store['colors'] + is tractogram2.data_per_point.store['colors']) + + # The affine should be a copy. + assert_true(tractogram1._affine_to_apply + is not tractogram2._affine_to_apply) + assert_array_equal(tractogram1._affine_to_apply, + tractogram2._affine_to_apply) + + # Check the data are the equivalent. 
+        assert_equal(tractogram1._nb_streamlines, tractogram2._nb_streamlines)
+        assert_true(check_iteration(tractogram2))
+        assert_equal(len(tractogram1), len(tractogram2))
+        assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines)
+        assert_arrays_equal(tractogram1.data_per_streamline['mean_curv'],
+                            tractogram2.data_per_streamline['mean_curv'])
+        assert_arrays_equal(tractogram1.data_per_streamline['mean_color'],
+                            tractogram2.data_per_streamline['mean_color'])
+        assert_arrays_equal(tractogram1.data_per_point['colors'],
+                            tractogram2.data_per_point['colors'])
diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py
new file mode 100644
index 0000000000..f689e53ef8
--- /dev/null
+++ b/nibabel/streamlines/tests/test_trk.py
@@ -0,0 +1,387 @@
+import os
+import unittest
+import numpy as np
+
+from nibabel.externals.six import BytesIO
+
+from nibabel.testing import suppress_warnings, clear_and_catch_warnings
+from nibabel.testing import assert_arrays_equal, check_iteration
+from nose.tools import assert_equal, assert_raises, assert_true
+
+from .test_tractogram import assert_tractogram_equal
+from ..tractogram import Tractogram, LazyTractogram
+from ..tractogram_file import DataError, HeaderError, HeaderWarning
+
+from .. import trk as trk_module
+from ..trk import TrkFile, header_2_dtype
+
+DATA_PATH = os.path.join(os.path.dirname(__file__), 'data')
+
+
+def assert_header_equal(h1, h2):
+    header1 = np.zeros(1, dtype=header_2_dtype)
+    header2 = np.zeros(1, dtype=header_2_dtype)
+
+    for k, v in h1.items():
+        header1[k] = v
+
+    for k, v in h2.items():
+        header2[k] = v
+
+    assert_equal(header1, header2)
+
+
+class TestTRK(unittest.TestCase):
+
+    def setUp(self):
+        self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk")
+        # simple.trk contains only streamlines
+        self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk")
+        # standard.trk contains only streamlines
+        self.standard_trk_filename = os.path.join(DATA_PATH, "standard.trk")
+        # standard.LPS.trk contains only streamlines
+        self.standard_LPS_trk_filename = os.path.join(DATA_PATH, "standard.LPS.trk")
+
+        # complex.trk contains streamlines, scalars and properties
+        self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk")
+
+        self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)),
+                            np.arange(2*3, dtype="f4").reshape((2, 3)),
+                            np.arange(5*3, dtype="f4").reshape((5, 3))]
+
+        self.fa = [np.array([[0.2]], dtype="f4"),
+                   np.array([[0.3],
+                             [0.4]], dtype="f4"),
+                   np.array([[0.5],
+                             [0.6],
+                             [0.6],
+                             [0.7],
+                             [0.8]], dtype="f4")]
+
+        self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"),
+                       np.array([(0, 1, 0)]*2, dtype="f4"),
+                       np.array([(0, 0, 1)]*5, dtype="f4")]
+
+        self.mean_curvature = [np.array([1.11], dtype="f4"),
+                               np.array([2.11], dtype="f4"),
+                               np.array([3.11], dtype="f4")]
+
+        self.mean_torsion = [np.array([1.22], dtype="f4"),
+                             np.array([2.22], dtype="f4"),
+                             np.array([3.22], dtype="f4")]
+
+        self.mean_colors = [np.array([1, 0, 0], dtype="f4"),
+                            np.array([0, 1, 0], dtype="f4"),
+                            np.array([0, 0, 1], dtype="f4")]
+
+        self.data_per_point = {'colors': self.colors,
+                               'fa': self.fa}
+        self.data_per_streamline = {'mean_curvature': self.mean_curvature,
+                                    'mean_torsion': self.mean_torsion,
+                                    'mean_colors': self.mean_colors}
+
+        self.empty_tractogram = Tractogram()
+        self.simple_tractogram = Tractogram(self.streamlines)
+        self.complex_tractogram = Tractogram(self.streamlines,
+                                             self.data_per_streamline,
self.data_per_point) + + def test_load_empty_file(self): + for lazy_load in [False, True]: + trk = TrkFile.load(self.empty_trk_filename, lazy_load=lazy_load) + assert_tractogram_equal(trk.tractogram, self.empty_tractogram) + + def test_load_simple_file(self): + for lazy_load in [False, True]: + trk = TrkFile.load(self.simple_trk_filename, lazy_load=lazy_load) + assert_tractogram_equal(trk.tractogram, self.simple_tractogram) + + def test_load_complex_file(self): + for lazy_load in [False, True]: + trk = TrkFile.load(self.complex_trk_filename, lazy_load=lazy_load) + assert_tractogram_equal(trk.tractogram, self.complex_tractogram) + + def test_load_file_with_wrong_information(self): + trk_file = open(self.simple_trk_filename, 'rb').read() + + # Simulate a TRK file where `count` was not provided. + count = np.array(0, dtype="int32").tostring() + new_trk_file = trk_file[:1000-12] + count + trk_file[1000-8:] + trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) + assert_tractogram_equal(trk.tractogram, self.simple_tractogram) + + # Simulate a TRK file where `voxel_order` was not provided. + voxel_order = np.zeros(1, dtype="|S3").tostring() + new_trk_file = trk_file[:948] + voxel_order + trk_file[948+3:] + with clear_and_catch_warnings(record=True, modules=[trk_module]) as w: + TrkFile.load(BytesIO(new_trk_file)) + assert_equal(len(w), 1) + assert_true(issubclass(w[0].category, HeaderWarning)) + assert_true("LPS" in str(w[0].message)) + + # Simulate a TRK file with an unsupported version. + version = np.int32(123).tostring() + new_trk_file = trk_file[:992] + version + trk_file[992+4:] + assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) + + # Simulate a TRK file with a wrong hdr_size. + hdr_size = np.int32(1234).tostring() + new_trk_file = trk_file[:996] + hdr_size + trk_file[996+4:] + assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) + + def test_write_empty_file(self): + tractogram = Tractogram() + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + new_trk_orig = TrkFile.load(self.empty_trk_filename) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) + + trk_file.seek(0, os.SEEK_SET) + assert_equal(trk_file.read(), open(self.empty_trk_filename, 'rb').read()) + + def test_write_simple_file(self): + tractogram = Tractogram(self.streamlines) + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + new_trk_orig = TrkFile.load(self.simple_trk_filename) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) + + trk_file.seek(0, os.SEEK_SET) + assert_equal(trk_file.read(), open(self.simple_trk_filename, 'rb').read()) + + def test_write_complex_file(self): + # With scalars + tractogram = Tractogram(self.streamlines, + data_per_point=self.data_per_point) + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + # With properties + tractogram = Tractogram(self.streamlines, + data_per_streamline=self.data_per_streamline) + + trk = TrkFile(tractogram) + trk_file = BytesIO() + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file, lazy_load=False) + 
assert_tractogram_equal(new_trk.tractogram, tractogram) + + # With scalars and properties + tractogram = Tractogram(self.streamlines, + data_per_point=self.data_per_point, + data_per_streamline=self.data_per_streamline) + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + new_trk_orig = TrkFile.load(self.complex_trk_filename) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) + + trk_file.seek(0, os.SEEK_SET) + assert_equal(trk_file.read(), + open(self.complex_trk_filename, 'rb').read()) + + def test_write_erroneous_file(self): + # No scalars for every points + scalars = [[(1, 0, 0)], + [(0, 1, 0)], + [(0, 0, 1)]] + + tractogram = Tractogram(self.streamlines, + data_per_point={'scalars': scalars}) + trk = TrkFile(tractogram) + assert_raises(DataError, trk.save, BytesIO()) + + # No scalars for every streamlines + scalars = [[(1, 0, 0)]*1, + [(0, 1, 0)]*2] + + tractogram = Tractogram(self.streamlines, + data_per_point={'scalars': scalars}) + trk = TrkFile(tractogram) + assert_raises(IndexError, trk.save, BytesIO()) + + # Inconsistent number of properties + properties = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + tractogram = Tractogram(self.streamlines, + data_per_streamline={'properties': properties}) + trk = TrkFile(tractogram) + assert_raises(DataError, trk.save, BytesIO()) + + # No properties for every streamlines + properties = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11, 2.22], dtype="f4")] + tractogram = Tractogram(self.streamlines, + data_per_streamline={'properties': properties}) + trk = TrkFile(tractogram) + assert_raises(IndexError, trk.save, BytesIO()) + + def test_load_write_file(self): + for filename in [self.empty_trk_filename, self.simple_trk_filename, self.complex_trk_filename]: + for lazy_load in [False, True]: + trk = TrkFile.load(filename, lazy_load=lazy_load) + trk_file = BytesIO() + trk.save(trk_file) + + new_trk = TrkFile.load(filename, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, trk.tractogram) + + trk_file.seek(0, os.SEEK_SET) + #assert_equal(trk_file.read(), open(filename, 'rb').read()) + + def test_load_write_LPS_file(self): + # Load the RAS and LPS version of the standard. + trk_RAS = TrkFile.load(self.standard_trk_filename, lazy_load=False) + trk_LPS = TrkFile.load(self.standard_LPS_trk_filename, lazy_load=False) + assert_tractogram_equal(trk_LPS.tractogram, trk_RAS.tractogram) + + # Write back the standard. + trk_file = BytesIO() + trk = TrkFile(trk_LPS.tractogram, trk_LPS.header) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file) + + assert_header_equal(new_trk.header, trk.header) + assert_tractogram_equal(new_trk.tractogram, trk.tractogram) + + new_trk_orig = TrkFile.load(self.standard_LPS_trk_filename) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) + + trk_file.seek(0, os.SEEK_SET) + assert_equal(trk_file.read(), + open(self.standard_LPS_trk_filename, 'rb').read()) + + def test_write_too_many_scalars_and_properties(self): + # TRK supports up to 10 data_per_point. 
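+        # (the TRK header has ten 20-byte 'scalar_name' slots; see
+        # MAX_NB_NAMED_SCALARS_PER_POINT in trk.py).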
+        data_per_point = {}
+        for i in range(10):
+            data_per_point['#{0}'.format(i)] = self.fa
+
+        tractogram = Tractogram(self.streamlines,
+                                data_per_point=data_per_point)
+
+        trk_file = BytesIO()
+        trk = TrkFile(tractogram)
+        trk.save(trk_file)
+        trk_file.seek(0, os.SEEK_SET)
+
+        new_trk = TrkFile.load(trk_file, lazy_load=False)
+        assert_tractogram_equal(new_trk.tractogram, tractogram)
+
+        # More than 10 data_per_point should raise an error.
+        data_per_point['#{0}'.format(i+1)] = self.fa
+
+        tractogram = Tractogram(self.streamlines,
+                                data_per_point=data_per_point)
+
+        trk = TrkFile(tractogram)
+        assert_raises(ValueError, trk.save, BytesIO())
+
+        # TRK supports up to 10 data_per_streamline.
+        data_per_streamline = {}
+        for i in range(10):
+            data_per_streamline['#{0}'.format(i)] = self.mean_torsion
+
+        tractogram = Tractogram(self.streamlines,
+                                data_per_streamline=data_per_streamline)
+
+        trk_file = BytesIO()
+        trk = TrkFile(tractogram)
+        trk.save(trk_file)
+        trk_file.seek(0, os.SEEK_SET)
+
+        new_trk = TrkFile.load(trk_file, lazy_load=False)
+        assert_tractogram_equal(new_trk.tractogram, tractogram)
+
+        # More than 10 data_per_streamline should raise an error.
+        data_per_streamline['#{0}'.format(i+1)] = self.mean_torsion
+
+        tractogram = Tractogram(self.streamlines,
+                                data_per_streamline=data_per_streamline)
+
+        trk = TrkFile(tractogram)
+        assert_raises(ValueError, trk.save, BytesIO())
+
+    def test_write_scalars_and_properties_name_too_long(self):
+        # TRK supports data_per_point names of up to 20 characters.
+        # However, when a key stores more than one value per point, the
+        # last two bytes are reserved to encode the number of values, so
+        # only 18 characters are allowed; longer names make `save` raise
+        # a ValueError.
+        for nb_chars in range(22):
+            data_per_point = {'A'*nb_chars: self.colors}
+            tractogram = Tractogram(self.streamlines,
+                                    data_per_point=data_per_point)
+
+            trk = TrkFile(tractogram)
+            if nb_chars > 18:
+                assert_raises(ValueError, trk.save, BytesIO())
+            else:
+                trk.save(BytesIO())
+
+            data_per_point = {'A'*nb_chars: self.fa}
+            tractogram = Tractogram(self.streamlines,
+                                    data_per_point=data_per_point)
+
+            trk = TrkFile(tractogram)
+            if nb_chars > 20:
+                assert_raises(ValueError, trk.save, BytesIO())
+            else:
+                trk.save(BytesIO())
+
+        # TRK supports data_per_streamline names of up to 20 characters.
+        # However, when a key stores more than one value, the last two
+        # bytes are reserved to encode the number of values, so only 18
+        # characters are allowed; longer names make `save` raise a
+        # ValueError.
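+        # For instance, 'mean_colors' (three values) must fit in 18
+        # characters, while 'mean_torsion' (a single value) may use all 20.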
+        for nb_chars in range(22):
+            data_per_streamline = {'A'*nb_chars: self.mean_colors}
+            tractogram = Tractogram(self.streamlines,
+                                    data_per_streamline=data_per_streamline)
+
+            trk = TrkFile(tractogram)
+            if nb_chars > 18:
+                assert_raises(ValueError, trk.save, BytesIO())
+            else:
+                trk.save(BytesIO())
+
+            data_per_streamline = {'A'*nb_chars: self.mean_torsion}
+            tractogram = Tractogram(self.streamlines,
+                                    data_per_streamline=data_per_streamline)
+
+            trk = TrkFile(tractogram)
+            if nb_chars > 20:
+                assert_raises(ValueError, trk.save, BytesIO())
+            else:
+                trk.save(BytesIO())
diff --git a/nibabel/streamlines/tests/test_utils.py b/nibabel/streamlines/tests/test_utils.py
new file mode 100644
index 0000000000..6c3bf096a6
--- /dev/null
+++ b/nibabel/streamlines/tests/test_utils.py
@@ -0,0 +1,34 @@
+import os
+import numpy as np
+import nibabel as nib
+
+from nibabel.testing import data_path
+from numpy.testing import assert_array_equal
+from nose.tools import assert_equal, assert_raises, assert_true
+
+from ..utils import pop, get_affine_from_reference
+
+
+def test_pop():
+    gen = (i for i in range(3))
+    assert_equal(pop(gen), 0)
+    assert_equal(pop(gen), 1)
+    assert_equal(pop(gen), 2)
+    assert_true(pop(gen) is None)
+
+
+def test_get_affine_from_reference():
+    filename = os.path.join(data_path, 'example_nifti2.nii.gz')
+    img = nib.load(filename)
+    affine = img.affine
+
+    # Get affine from a numpy array.
+    assert_array_equal(get_affine_from_reference(affine), affine)
+    wrong_ref = np.array([[1, 2, 3], [4, 5, 6]])
+    assert_raises(ValueError, get_affine_from_reference, wrong_ref)
+
+    # Get affine from a `SpatialImage`.
+    assert_array_equal(get_affine_from_reference(img), affine)
+
+    # Get affine from a `SpatialImage` by its filename.
+    assert_array_equal(get_affine_from_reference(filename), affine)
diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
new file mode 100644
index 0000000000..29fb975eee
--- /dev/null
+++ b/nibabel/streamlines/tractogram.py
@@ -0,0 +1,398 @@
+import numpy as np
+import collections
+from warnings import warn
+
+from nibabel.affines import apply_affine
+
+from .compact_list import CompactList
+
+
+class UsageWarning(Warning):
+    pass
+
+
+class TractogramItem(object):
+    """ Class containing information about one streamline.
+
+    ``TractogramItem`` objects have three main properties: `streamline`,
+    `data_for_streamline`, and `data_for_points`.
+
+    Parameters
+    ----------
+    streamline : ndarray of shape (N, 3)
+        Points of this streamline represented as an ndarray of shape (N, 3)
+        where N is the number of points.
+
+    data_for_streamline : dict
+        Data associated with this streamline. Each key maps to an ndarray of
+        shape (P,), where P is the number of values for that key.
+
+    data_for_points : dict
+        Data associated with each point of this streamline. Each key maps to
+        an ndarray of shape (N, M), where M is the number of values for that
+        key at each point.
+    """
+    def __init__(self, streamline, data_for_streamline, data_for_points):
+        self.streamline = np.asarray(streamline)
+        self.data_for_streamline = data_for_streamline
+        self.data_for_points = data_for_points
+
+    def __iter__(self):
+        return iter(self.streamline)
+
+    def __len__(self):
+        return len(self.streamline)
+
+
+class Tractogram(object):
+    """ Class containing information about streamlines.
+
+    Tractogram objects have three main properties: ``streamlines``,
+    ``data_per_streamline`` and ``data_per_point``.
+
+    Parameters
+    ----------
+    streamlines : list of ndarray of shape (Nt, 3)
+        Sequence of T streamlines. One streamline is an ndarray of shape
+        (Nt, 3) where Nt is the number of points of streamline t.
+
+    data_per_streamline : dictionary of list of ndarray of shape (P,)
+        Sequence of T ndarrays of shape (P,) where T is the number of
+        streamlines defined by ``streamlines``, P is the number of properties
+        associated with each streamline.
+
+    data_per_point : dictionary of list of ndarray of shape (Nt, M)
+        Sequence of T ndarrays of shape (Nt, M) where T is the number of
+        streamlines defined by ``streamlines``, Nt is the number of points
+        for a particular streamline t and M is the number of scalars
+        associated with each point (excluding the three coordinates).
+
+    """
+    def __init__(self, streamlines=None,
+                 data_per_streamline=None,
+                 data_per_point=None):
+
+        self.streamlines = streamlines
+        self.data_per_streamline = data_per_streamline
+        self.data_per_point = data_per_point
+
+    @property
+    def streamlines(self):
+        return self._streamlines
+
+    @streamlines.setter
+    def streamlines(self, value):
+        self._streamlines = CompactList(value)
+
+    @property
+    def data_per_streamline(self):
+        return self._data_per_streamline
+
+    @data_per_streamline.setter
+    def data_per_streamline(self, value):
+        if value is None:
+            value = {}
+
+        self._data_per_streamline = {}
+        for k, v in value.items():
+            self._data_per_streamline[k] = np.asarray(v)
+
+    @property
+    def data_per_point(self):
+        return self._data_per_point
+
+    @data_per_point.setter
+    def data_per_point(self, value):
+        if value is None:
+            value = {}
+
+        self._data_per_point = {}
+        for k, v in value.items():
+            self._data_per_point[k] = CompactList(v)
+
+    def __iter__(self):
+        for i in range(len(self.streamlines)):
+            yield self[i]
+
+    def __getitem__(self, idx):
+        pts = self.streamlines[idx]
+
+        data_per_streamline = {}
+        for key in self.data_per_streamline:
+            data_per_streamline[key] = self.data_per_streamline[key][idx]
+
+        data_per_point = {}
+        for key in self.data_per_point:
+            data_per_point[key] = self.data_per_point[key][idx]
+
+        if isinstance(idx, int) or isinstance(idx, np.integer):
+            return TractogramItem(pts, data_per_streamline, data_per_point)
+
+        return Tractogram(pts, data_per_streamline, data_per_point)
+
+    def __len__(self):
+        return len(self.streamlines)
+
+    def copy(self):
+        """ Returns a copy of this `Tractogram` object. """
+        data_per_streamline = {}
+        for key in self.data_per_streamline:
+            data_per_streamline[key] = self.data_per_streamline[key].copy()
+
+        data_per_point = {}
+        for key in self.data_per_point:
+            data_per_point[key] = self.data_per_point[key].copy()
+
+        tractogram = Tractogram(self.streamlines.copy(),
+                                data_per_streamline,
+                                data_per_point)
+        return tractogram
+
+    def apply_affine(self, affine):
+        """ Applies an affine transformation on the points of each streamline.
+
+        This is performed in-place.
+
+        Parameters
+        ----------
+        affine : 2D array (4,4)
+            Transformation that will be applied on each streamline.
+        """
+        if len(self.streamlines) == 0:
+            return
+
+        BUFFER_SIZE = 10000000  # About 128 Mb since pts shape is 3.
+        for i in range(0, len(self.streamlines._data), BUFFER_SIZE):
+            pts = self.streamlines._data[i:i+BUFFER_SIZE]
+            self.streamlines._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts)
+
+
+class LazyTractogram(Tractogram):
+    ''' Class containing information about streamlines.
+
+    LazyTractogram objects have three main properties: ``streamlines``,
+    ``data_per_streamline`` and ``data_per_point``, all of which are
+    evaluated lazily. LazyTractogram objects are iterable and produce a
+    ``TractogramItem`` for each streamline.
+
+    Parameters
+    ----------
+    streamlines : coroutine outputting (Nt, 3) array-like (optional)
+        Function yielding streamlines. One streamline is an ndarray of
+        shape (Nt, 3) where Nt is the number of points of streamline t.
+
+    data_per_streamline : dict of coroutines outputting (P,) array-like (optional)
+        Dictionary where each value is a function yielding, for each
+        streamline t, an ndarray of shape (P,) where P is the number of
+        values associated with that streamline.
+
+    data_per_point : dict of coroutines outputting (Nt, M) array-like (optional)
+        Dictionary where each value is a function yielding, for each
+        streamline t, an ndarray of shape (Nt, M) where Nt is the number of
+        points of that streamline and M is the number of values associated
+        with each point (excluding the three coordinates).
+
+    Notes
+    -----
+    If provided, the coroutines in ``data_per_streamline`` and
+    ``data_per_point`` must yield the same number of values as
+    ``streamlines``.
+    '''
+
+    class LazyDict(collections.MutableMapping):
+        """ Internal dictionary with lazy evaluations. """
+
+        def __init__(self, *args, **kwargs):
+            self.store = dict()
+
+            # Use update to set keys.
+            if len(args) == 1 and isinstance(args[0], LazyTractogram.LazyDict):
+                self.update(dict(args[0].store.items()))
+            else:
+                self.update(dict(*args, **kwargs))
+
+        def __getitem__(self, key):
+            return self.store[key]()
+
+        def __setitem__(self, key, value):
+            if value is not None and not callable(value):
+                raise TypeError("`value` must be a coroutine or None.")
+
+            self.store[key] = value
+
+        def __delitem__(self, key):
+            del self.store[key]
+
+        def __iter__(self):
+            return iter(self.store)
+
+        def __len__(self):
+            return len(self.store)
+
+    def __init__(self, streamlines=None, data_per_streamline=None,
+                 data_per_point=None):
+        super(LazyTractogram, self).__init__(streamlines, data_per_streamline,
+                                             data_per_point)
+        self._nb_streamlines = None
+        self._data = None
+        self._affine_to_apply = np.eye(4)
+
+    @classmethod
+    def create_from(cls, data_func):
+        ''' Creates a `LazyTractogram` from a coroutine yielding
+        `TractogramItem` objects.
+
+        Parameters
+        ----------
+        data_func : coroutine yielding `TractogramItem` objects
+            A function that, whenever it is called, starts yielding
+            `TractogramItem` objects that should be part of this
+            LazyTractogram.
+ + ''' + if not callable(data_func): + raise TypeError("`data_func` must be a coroutine.") + + lazy_tractogram = cls() + lazy_tractogram._data = data_func + + try: + first_item = next(data_func()) + + # Set data_per_streamline using data_func + def _gen(key): + return lambda: (t.data_for_streamline[key] for t in data_func()) + + data_per_streamline_keys = first_item.data_for_streamline.keys() + for k in data_per_streamline_keys: + lazy_tractogram._data_per_streamline[k] = _gen(k) + + # Set data_per_point using data_func + def _gen(key): + return lambda: (t.data_for_points[key] for t in data_func()) + + data_per_point_keys = first_item.data_for_points.keys() + for k in data_per_point_keys: + lazy_tractogram._data_per_point[k] = _gen(k) + + except StopIteration: + pass + + return lazy_tractogram + + @property + def streamlines(self): + streamlines_gen = iter([]) + if self._streamlines is not None: + streamlines_gen = self._streamlines() + elif self._data is not None: + streamlines_gen = (t.streamline for t in self._data()) + + # Check if we need to apply an affine. + if not np.all(self._affine_to_apply == np.eye(4)): + def _apply_affine(): + for s in streamlines_gen: + yield apply_affine(self._affine_to_apply, s) + + return _apply_affine() + + return streamlines_gen + + @streamlines.setter + def streamlines(self, value): + if value is not None and not callable(value): + raise TypeError("`streamlines` must be a coroutine.") + + self._streamlines = value + + @property + def data_per_streamline(self): + return self._data_per_streamline + + @data_per_streamline.setter + def data_per_streamline(self, value): + if value is None: + value = {} + + self._data_per_streamline = LazyTractogram.LazyDict(value) + + @property + def data_per_point(self): + return self._data_per_point + + @data_per_point.setter + def data_per_point(self, value): + if value is None: + value = {} + + self._data_per_point = LazyTractogram.LazyDict(value) + + @property + def data(self): + if self._data is not None: + return self._data() + + def _gen_data(): + data_per_streamline_generators = {} + for k, v in self.data_per_streamline.items(): + data_per_streamline_generators[k] = iter(v) + + data_per_point_generators = {} + for k, v in self.data_per_point.items(): + data_per_point_generators[k] = iter(v) + + for s in self.streamlines: + data_for_streamline = {} + for k, v in data_per_streamline_generators.items(): + data_for_streamline[k] = next(v) + + data_for_points = {} + for k, v in data_per_point_generators.items(): + data_for_points[k] = next(v) + + yield TractogramItem(s, data_for_streamline, data_for_points) + + return _gen_data() + + def __getitem__(self, idx): + raise AttributeError('`LazyTractogram` does not support indexing.') + + def __iter__(self): + i = 0 + for i, tractogram_item in enumerate(self.data, start=1): + yield tractogram_item + + # Keep how many streamlines there are in this tractogram. + self._nb_streamlines = i + + def __len__(self): + # Check if we know how many streamlines there are. + if self._nb_streamlines is None: + warn("Number of streamlines will be determined manually by looping" + " through the streamlines. If you know the actual number of" + " streamlines, you might want to set it beforehand via" + " `self.header.nb_streamlines`." + " Note this will consume any generators used to create this" + " `LazyTractogram` object.", UsageWarning) + # Count the number of streamlines. 
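+        # (this requires one full pass over the `streamlines` coroutine).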
+ self._nb_streamlines = sum(1 for _ in self.streamlines) + + return self._nb_streamlines + + def copy(self): + """ Returns a copy of this `LazyTractogram` object. """ + tractogram = LazyTractogram(self._streamlines, + self._data_per_streamline, + self._data_per_point) + tractogram._nb_streamlines = self._nb_streamlines + tractogram._data = self._data + tractogram._affine_to_apply = self._affine_to_apply.copy() + return tractogram + + def apply_affine(self, affine): + """ Applies an affine transformation on the streamlines. + + Parameters + ---------- + affine : 2D array (4,4) + Transformation that will be applied on each streamline. + """ + self._affine_to_apply = np.dot(affine, self._affine_to_apply) diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py new file mode 100644 index 0000000000..8bb1aa41ea --- /dev/null +++ b/nibabel/streamlines/tractogram_file.py @@ -0,0 +1,117 @@ +from abc import ABCMeta, abstractmethod, abstractproperty + +from .header import TractogramHeader + + +class HeaderWarning(Warning): + pass + + +class HeaderError(Exception): + pass + + +class DataError(Exception): + pass + + +class abstractclassmethod(classmethod): + __isabstractmethod__ = True + + def __init__(self, callable): + callable.__isabstractmethod__ = True + super(abstractclassmethod, self).__init__(callable) + + +class TractogramFile(object): + ''' Convenience class to encapsulate tractogram file format. ''' + __metaclass__ = ABCMeta + + def __init__(self, tractogram, header=None): + self._tractogram = tractogram + self._header = {} if header is None else header + #self._header = TractogramHeader() if header is None else header + #self._header = TractogramHeader(header) + + @property + def tractogram(self): + return self._tractogram + + @property + def streamlines(self): + return self.tractogram.streamlines + + @property + def header(self): + return self._header + + def get_tractogram(self): + return self.tractogram + + def get_streamlines(self): + return self.streamlines + + def get_header(self): + return self.header + + @classmethod + def get_magic_number(cls): + ''' Returns streamlines file's magic number. ''' + raise NotImplementedError() + + @classmethod + def support_data_per_point(cls): + ''' Tells if this tractogram format supports saving data per point. ''' + raise NotImplementedError() + + @classmethod + def support_data_per_streamline(cls): + ''' Tells if this tractogram format supports saving data per streamline. ''' + raise NotImplementedError() + + @classmethod + def is_correct_format(cls, fileobj): + ''' Checks if the file has the right streamlines file format. + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header). + Returns + ------- + is_correct_format : boolean + Returns True if `fileobj` is in the right streamlines file format. + ''' + raise NotImplementedError() + + @abstractclassmethod + def load(cls, fileobj, lazy_load=True): + ''' Loads streamlines from a file-like object. + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header). + lazy_load : boolean (optional) + Load streamlines in a lazy manner i.e. they will not be kept + in memory. For postprocessing speed, turn off this option. 
+ Returns + ------- + tractogram_file : ``TractogramFile`` object + Returns an object containing tractogram data and header + information. + ''' + raise NotImplementedError() + + @abstractmethod + def save(self, fileobj): + ''' Saves streamlines to a file-like object. + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + opened and ready to write. + ''' + raise NotImplementedError() diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py new file mode 100644 index 0000000000..9b043ff821 --- /dev/null +++ b/nibabel/streamlines/trk.py @@ -0,0 +1,739 @@ +from __future__ import division + +# Documentation available here: +# http://www.trackvis.org/docs/?subsect=fileformat + +import os +import struct +import warnings +import itertools + +import numpy as np +import nibabel as nib + +from nibabel.affines import apply_affine +from nibabel.openers import Opener +from nibabel.py3k import asbytes, asstr +from nibabel.volumeutils import (native_code, swapped_code) + +from .compact_list import CompactList +from .tractogram_file import TractogramFile +from .tractogram_file import DataError, HeaderError, HeaderWarning +from .tractogram import TractogramItem, Tractogram, LazyTractogram +from .header import Field + + +MAX_NB_NAMED_SCALARS_PER_POINT = 10 +MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE = 10 + +# Definition of trackvis header structure. +# See http://www.trackvis.org/docs/?subsect=fileformat +# See http://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html +header_1_dtd = [(Field.MAGIC_NUMBER, 'S6'), + (Field.DIMENSIONS, 'h', 3), + (Field.VOXEL_SIZES, 'f4', 3), + (Field.ORIGIN, 'f4', 3), + (Field.NB_SCALARS_PER_POINT, 'h'), + ('scalar_name', 'S20', MAX_NB_NAMED_SCALARS_PER_POINT), + (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), + ('property_name', 'S20', MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE), + ('reserved', 'S508'), + (Field.VOXEL_ORDER, 'S4'), + ('pad2', 'S4'), + ('image_orientation_patient', 'f4', 6), + ('pad1', 'S2'), + ('invert_x', 'S1'), + ('invert_y', 'S1'), + ('invert_z', 'S1'), + ('swap_xy', 'S1'), + ('swap_yz', 'S1'), + ('swap_zx', 'S1'), + (Field.NB_STREAMLINES, 'i4'), + ('version', 'i4'), + ('hdr_size', 'i4'), + ] + +# Version 2 adds a 4x4 matrix giving the affine transformtation going +# from voxel coordinates in the referenced 3D voxel matrix, to xyz +# coordinates (axes L->R, P->A, I->S). If (0 based) value [3, 3] from +# this matrix is 0, this means the matrix is not recorded. +header_2_dtd = [(Field.MAGIC_NUMBER, 'S6'), + (Field.DIMENSIONS, 'h', 3), + (Field.VOXEL_SIZES, 'f4', 3), + (Field.ORIGIN, 'f4', 3), + (Field.NB_SCALARS_PER_POINT, 'h'), + ('scalar_name', 'S20', MAX_NB_NAMED_SCALARS_PER_POINT), + (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), + ('property_name', 'S20', MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE), + (Field.VOXEL_TO_RASMM, 'f4', (4, 4)), # new field for version 2 + ('reserved', 'S444'), + (Field.VOXEL_ORDER, 'S4'), + ('pad2', 'S4'), + ('image_orientation_patient', 'f4', 6), + ('pad1', 'S2'), + ('invert_x', 'S1'), + ('invert_y', 'S1'), + ('invert_z', 'S1'), + ('swap_xy', 'S1'), + ('swap_yz', 'S1'), + ('swap_zx', 'S1'), + (Field.NB_STREAMLINES, 'i4'), + ('version', 'i4'), + ('hdr_size', 'i4'), + ] + +# Full header numpy dtypes +header_1_dtype = np.dtype(header_1_dtd) +header_2_dtype = np.dtype(header_2_dtd) + + +class TrkReader(object): + ''' Convenience class to encapsulate TRK file format. 
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object
+        pointing to TRK file (and ready to read from the beginning
+        of the TRK header)
+
+    Note
+    ----
+    TrackVis (and hence the TRK file format) considers the streamline
+    coordinate (0,0,0) to be in the corner of the voxel, whereas NiBabel's
+    internal streamline representation (voxel space) assumes (0,0,0) to be
+    in the center of the voxel.
+
+    Thus, streamlines are shifted by half a voxel on load and are shifted
+    back on save.
+    '''
+    def __init__(self, fileobj):
+        self.fileobj = fileobj
+
+        with Opener(self.fileobj) as f:
+            # Read header
+            header_str = f.read(header_2_dtype.itemsize)
+            header_rec = np.fromstring(string=header_str, dtype=header_2_dtype)
+
+            if header_rec['version'] == 1:
+                header_rec = np.fromstring(string=header_str, dtype=header_1_dtype)
+            elif header_rec['version'] == 2:
+                pass  # Nothing more to do
+            else:
+                raise HeaderError('NiBabel only supports versions 1 and 2.')
+
+            # Convert the first record of `header_rec` into a dictionary
+            self.header = dict(zip(header_rec.dtype.names, header_rec[0]))
+
+            # Check endianness
+            self.endianness = native_code
+            if self.header['hdr_size'] != TrkFile.HEADER_SIZE:
+                self.endianness = swapped_code
+
+                # Swap byte order
+                self.header = dict(zip(header_rec.dtype.names, header_rec[0].newbyteorder()))
+                if self.header['hdr_size'] != TrkFile.HEADER_SIZE:
+                    raise HeaderError('Invalid hdr_size: {0} instead of {1}'.format(self.header['hdr_size'], TrkFile.HEADER_SIZE))
+
+            # By default, the voxel order is LPS.
+            # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates
+            if self.header[Field.VOXEL_ORDER] == b"":
+                warnings.warn(("Voxel order is not specified, will assume"
+                               " 'LPS' since it is Trackvis software's"
+                               " default."), HeaderWarning)
+                self.header[Field.VOXEL_ORDER] = b"LPS"
+
+            # Keep the file position where the data begin.
+            self.offset_data = f.tell()
+
+    def __iter__(self):
+        i4_dtype = np.dtype(self.endianness + "i4")
+        f4_dtype = np.dtype(self.endianness + "f4")
+
+        with Opener(self.fileobj) as f:
+            start_position = f.tell()
+
+            nb_pts_and_scalars = int(3 + self.header[Field.NB_SCALARS_PER_POINT])
+            pts_and_scalars_size = int(nb_pts_and_scalars * f4_dtype.itemsize)
+            properties_size = int(self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize)
+
+            # Set the file position at the beginning of the data.
+            f.seek(self.offset_data, os.SEEK_SET)
+
+            # If the 'count' field is 0, i.e. not provided, we have to loop
+            # until the EOF.
+            nb_streamlines = self.header[Field.NB_STREAMLINES]
+            if nb_streamlines == 0:
+                nb_streamlines = np.inf
+
+            i = 0
+            nb_pts_dtype = i4_dtype.str[:-1]
+            while i < nb_streamlines:
+                nb_pts_str = f.read(i4_dtype.itemsize)
+
+                # Check if we reached EOF
+                if len(nb_pts_str) == 0:
+                    break
+
+                # Read number of points of the next streamline.
+                nb_pts = struct.unpack(nb_pts_dtype, nb_pts_str)[0]
+
+                # Read streamline's data
+                points_and_scalars = np.ndarray(shape=(nb_pts, nb_pts_and_scalars),
+                                                dtype=f4_dtype,
+                                                buffer=f.read(nb_pts * pts_and_scalars_size))
+
+                points = points_and_scalars[:, :3]
+                scalars = points_and_scalars[:, 3:]
+
+                # Read properties
+                properties = np.ndarray(shape=(self.header[Field.NB_PROPERTIES_PER_STREAMLINE],),
+                                        dtype=f4_dtype,
+                                        buffer=f.read(properties_size))
+
+                yield points, scalars, properties
+                i += 1
+
+            # In case the 'count' field was not provided.
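+            # Record the number of streamlines actually read so later
+            # consumers see a coherent header.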
+            self.header[Field.NB_STREAMLINES] = i
+
+            # Set the file position back to where it was (in case the file
+            # was already open). `start_position` is an absolute offset.
+            f.seek(start_position, os.SEEK_SET)
+
+
+class TrkWriter(object):
+    @classmethod
+    def create_empty_header(cls):
+        ''' Return an empty compliant TRK header. '''
+        header = np.zeros(1, dtype=header_2_dtype)
+
+        # Default values
+        header[Field.MAGIC_NUMBER] = TrkFile.MAGIC_NUMBER
+        header[Field.VOXEL_SIZES] = (1, 1, 1)
+        header[Field.DIMENSIONS] = (1, 1, 1)
+        header[Field.VOXEL_TO_RASMM] = np.eye(4)
+        header[Field.VOXEL_ORDER] = b"RAS"
+        header['version'] = 2
+        header['hdr_size'] = TrkFile.HEADER_SIZE
+
+        return header
+
+    def __init__(self, fileobj, header):
+        self.header = self.create_empty_header()
+
+        # Override the default fields with those contained in `header`.
+        for k, v in header.items():
+            if k in header_2_dtype.fields.keys():
+                self.header[k] = v
+
+        # By default, the voxel order is LPS.
+        # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates
+        if self.header[Field.VOXEL_ORDER] == b"":
+            self.header[Field.VOXEL_ORDER] = b"LPS"
+
+        # Keep counts so incoherent header fields can be corrected or
+        # warned about.
+        self.nb_streamlines = 0
+        self.nb_points = 0
+        self.nb_scalars = 0
+        self.nb_properties = 0
+
+        # Write header
+        self.header = self.header[0]
+        self.file = Opener(fileobj, mode="wb")
+        # Keep track of the beginning of the header.
+        self.beginning = self.file.tell()
+        self.file.write(self.header.tostring())
+
+    def write(self, tractogram):
+        i4_dtype = np.dtype("i4")
+        f4_dtype = np.dtype("f4")
+
+        try:
+            first_item = next(iter(tractogram))
+        except StopIteration:
+            # Empty tractogram
+            self.header[Field.NB_STREAMLINES] = 0
+            self.header[Field.NB_SCALARS_PER_POINT] = 0
+            self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = 0
+            # Overwrite header with updated one.
+            self.file.seek(self.beginning, os.SEEK_SET)
+            self.file.write(self.header.tostring())
+            return
+
+        # Update the 'property_name' field using 'data_per_streamline' of the tractogram.
+        data_for_streamline = first_item.data_for_streamline
+        if len(data_for_streamline) > MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE:
+            raise ValueError("Can only store {0} named data_per_streamline (properties).".format(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE))
+
+        data_for_streamline_keys = sorted(data_for_streamline.keys())
+        self.header['property_name'] = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, dtype='S20')
+        for i, k in enumerate(data_for_streamline_keys):
+            nb_values = data_for_streamline[k].shape[0]
+
+            if len(k) > 20:
+                raise ValueError("Property name '{0}' is too long (max 20 char.)".format(k))
+            elif len(k) > 18 and nb_values > 1:
+                raise ValueError("Property name '{0}' is too long (must be less than 18 characters when storing more than one value).".format(k))
+
+            property_name = k
+            if nb_values > 1:
+                # Use the last two bytes of the name to store the number of
+                # values associated with this data_for_streamline key.
+                property_name = asbytes(k[:18].ljust(18, '\x00')) + b'\x00' + np.array(nb_values, dtype=np.int8).tostring()
+
+            self.header['property_name'][i] = property_name
+
+        # Update the 'scalar_name' field using 'data_per_point' of the tractogram.
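+        # The same naming scheme as for 'property_name' applies: when a key
+        # holds more than one value, bytes 19-20 of its 20-byte slot encode
+        # '\x00' followed by the value count.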
+        data_for_points = first_item.data_for_points
+        if len(data_for_points) > MAX_NB_NAMED_SCALARS_PER_POINT:
+            raise ValueError("Can only store {0} named data_per_point (scalars).".format(MAX_NB_NAMED_SCALARS_PER_POINT))
+
+        data_for_points_keys = sorted(data_for_points.keys())
+        self.header['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20')
+        for i, k in enumerate(data_for_points_keys):
+            nb_values = data_for_points[k].shape[1]
+
+            if len(k) > 20:
+                raise ValueError("Scalar name '{0}' is too long (max 20 char.)".format(k))
+            elif len(k) > 18 and nb_values > 1:
+                raise ValueError("Scalar name '{0}' is too long (must be less than 18 characters when storing more than one value).".format(k))
+
+            scalar_name = k
+            if nb_values > 1:
+                # Use the last two bytes of the name to store the number of
+                # values associated with this data_for_points key.
+                scalar_name = asbytes(k[:18].ljust(18, '\x00')) + b'\x00' + np.array(nb_values, dtype=np.int8).tostring()
+
+            self.header['scalar_name'][i] = scalar_name
+
+        # `Tractogram` streamlines are in RAS+ and mm space, we will compute
+        # the affine matrix that will bring them back to 'voxelmm' as required
+        # by the TRK format.
+        affine = np.eye(4)
+
+        # Apply the inverse of the affine found in the TRK header.
+        # rasmm -> voxel
+        affine = np.dot(np.linalg.inv(self.header[Field.VOXEL_TO_RASMM]), affine)
+
+        # If the voxel order implied by the affine does not match the voxel
+        # order in the TRK header, change the orientation.
+        # voxel (affine) -> voxel (header)
+        header_ornt = asstr(self.header[Field.VOXEL_ORDER])
+        affine_ornt = "".join(nib.orientations.aff2axcodes(self.header[Field.VOXEL_TO_RASMM]))
+        header_ornt = nib.orientations.axcodes2ornt(header_ornt)
+        affine_ornt = nib.orientations.axcodes2ornt(affine_ornt)
+        ornt = nib.orientations.ornt_transform(affine_ornt, header_ornt)
+        M = nib.orientations.inv_ornt_aff(ornt, self.header[Field.DIMENSIONS])
+        affine = np.dot(M, affine)
+
+        # TrackVis considers coordinate (0,0,0) to be the corner of the
+        # voxel whereas `Tractogram` streamlines assume (0,0,0) is the
+        # center of the voxel. Thus, streamlines are shifted by half a voxel.
+        offset = np.eye(4)
+        offset[:-1, -1] += 0.5
+        affine = np.dot(offset, affine)
+
+        # Finally, send the streamlines in mm space.
+        # voxel -> voxelmm
+        scale = np.eye(4)
+        scale[range(3), range(3)] *= self.header[Field.VOXEL_SIZES]
+        affine = np.dot(scale, affine)
+
+        # The TRK format uses float32 as the data type for points.
+        affine = affine.astype(np.float32)
+
+        for t in tractogram:
+            if any((len(d) != len(t.streamline) for d in t.data_for_points.values())):
+                raise DataError("Missing scalars for some points!")
+
+            points = apply_affine(affine, np.asarray(t.streamline, dtype=f4_dtype))
+            scalars = [np.asarray(t.data_for_points[k], dtype=f4_dtype) for k in data_for_points_keys]
+            scalars = np.concatenate([np.ndarray((len(points), 0), dtype=f4_dtype)] + scalars, axis=1)
+            properties = [np.asarray(t.data_for_streamline[k], dtype=f4_dtype) for k in data_for_streamline_keys]
+            properties = np.concatenate([np.array([], dtype=f4_dtype)] + properties)
+
+            data = struct.pack(i4_dtype.str[:-1], len(points))
+            data += np.concatenate([points, scalars], axis=1).tostring()
+            data += properties.tostring()
+            self.file.write(data)
+
+            self.nb_streamlines += 1
+            self.nb_points += len(points)
+            self.nb_scalars += scalars.size
+            self.nb_properties += len(properties)
+
+        # Use those values to update the header.
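+        # E.g. 8 points carrying 2 scalars each give nb_scalars == 16 and
+        # nb_scalars_per_point == 2; a non-integer ratio means the number
+        # of scalars varied from one point to another.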
+ nb_scalars_per_point = self.nb_scalars / self.nb_points + nb_properties_per_streamline = self.nb_properties / self.nb_streamlines + + # Check for errors + if nb_scalars_per_point != int(nb_scalars_per_point): + raise DataError("Nb. of scalars differs from one point to another!") + + if nb_properties_per_streamline != int(nb_properties_per_streamline): + raise DataError("Nb. of properties differs from one streamline to another!") + + self.header[Field.NB_STREAMLINES] = self.nb_streamlines + self.header[Field.NB_SCALARS_PER_POINT] = nb_scalars_per_point + self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = nb_properties_per_streamline + + # Overwrite header with updated one. + self.file.seek(self.beginning, os.SEEK_SET) + self.file.write(self.header.tostring()) + + +def create_compactlist_from_generator(gen): + BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. + + streamlines = CompactList() + scalars = CompactList() + properties = np.array([]) + + gen = iter(gen) + try: + first_element = next(gen) + gen = itertools.chain([first_element], gen) + except StopIteration: + return streamlines, scalars, properties + + # Allocated some buffer memory. + pts = np.asarray(first_element[0]) + scals = np.asarray(first_element[1]) + props = np.asarray(first_element[2]) + + scals_shape = scals.shape + props_shape = props.shape + + streamlines._data = np.empty((BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype) + scalars._data = np.empty((BUFFER_SIZE, scals.shape[1]), dtype=scals.dtype) + properties = np.empty((BUFFER_SIZE, props.shape[0]), dtype=props.dtype) + + offset = 0 + for i, (pts, scals, props) in enumerate(gen): + pts = np.asarray(pts) + scals = np.asarray(scals) + props = np.asarray(props) + + if scals.shape[1] != scals_shape[1]: + raise ValueError("Number of scalars differs from one" + " point or streamline to another") + + if props.shape != props_shape: + raise ValueError("Number of properties differs from one" + " streamline to another") + + end = offset + len(pts) + if end >= len(streamlines._data): + # Resize is needed (at least `len(pts)` items will be added). + streamlines._data.resize((len(streamlines._data) + len(pts)+BUFFER_SIZE, pts.shape[1])) + scalars._data.resize((len(scalars._data) + len(scals)+BUFFER_SIZE, scals.shape[1])) + + streamlines._offsets.append(offset) + streamlines._lengths.append(len(pts)) + streamlines._data[offset:offset+len(pts)] = pts + scalars._data[offset:offset+len(scals)] = scals + + offset += len(pts) + + if i >= len(properties): + properties.resize((len(properties) + BUFFER_SIZE, props.shape[0])) + + properties[i] = props + + # Clear unused memory. + streamlines._data.resize((offset, pts.shape[1])) + + if scals_shape[1] == 0: + # Because resizing an empty ndarray creates memory! + scalars._data = np.empty((offset, scals.shape[1])) + else: + scalars._data.resize((offset, scals.shape[1])) + + # Share offsets and lengths between streamlines and scalars. + scalars._offsets = streamlines._offsets + scalars._lengths = streamlines._lengths + + if props_shape[0] == 0: + # Because resizing an empty ndarray creates memory! + properties = np.empty((i+1, props.shape[0])) + else: + properties.resize((i+1, props.shape[0])) + + return streamlines, scalars, properties + + + +class TrkFile(TractogramFile): + ''' Convenience class to encapsulate TRK file format. 
+
+    Note
+    ----
+    TrackVis (and hence the TRK file format) considers the streamline
+    coordinate (0,0,0) to be in the corner of the voxel, whereas NiBabel's
+    internal streamline representation (voxel space) assumes (0,0,0) to be
+    in the center of the voxel.
+
+    Thus, streamlines are shifted by half a voxel on load and are shifted
+    back on save.
+    '''
+
+    # Constants
+    MAGIC_NUMBER = b"TRACK"
+    HEADER_SIZE = 1000
+
+    def __init__(self, tractogram, header=None, ref=np.eye(4)):
+        """
+        Parameters
+        ----------
+        tractogram : ``Tractogram`` object
+            Tractogram that will be contained in this ``TrkFile``.
+
+        header : ``TractogramHeader`` file (optional)
+            Metadata associated with this tractogram file.
+
+        ref : filename | `Nifti1Image` object | 2D array (4,4) (optional)
+            Reference space in which the streamlines live.
+
+        Notes
+        -----
+        Streamlines of the tractogram are assumed to be in *RAS+* and *mm*
+        space where coordinate (0,0,0) refers to the center of the voxel.
+        """
+        if header is None:
+            header_rec = TrkWriter.create_empty_header()
+            header = dict(zip(header_rec.dtype.names, header_rec))
+
+        super(TrkFile, self).__init__(tractogram, header)
+        #self._affine = get_affine_from_reference(ref)
+
+    @classmethod
+    def get_magic_number(cls):
+        ''' Return TRK's magic number. '''
+        return cls.MAGIC_NUMBER
+
+    @classmethod
+    def support_data_per_point(cls):
+        ''' Tells if this tractogram format supports saving data per point. '''
+        return True
+
+    @classmethod
+    def support_data_per_streamline(cls):
+        ''' Tells if this tractogram format supports saving data per streamline. '''
+        return True
+
+    @classmethod
+    def is_correct_format(cls, fileobj):
+        ''' Check if the file is in TRK format.
+
+        Parameters
+        ----------
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            pointing to TRK file (and ready to read from the beginning
+            of the TRK header data).
+
+        Returns
+        -------
+        is_correct_format : boolean
+            Returns True if `fileobj` is in TRK format.
+        '''
+        with Opener(fileobj) as f:
+            magic_number = f.read(5)
+            f.seek(-5, os.SEEK_CUR)
+            return magic_number == cls.MAGIC_NUMBER
+
+        return False
+
+    @classmethod
+    def load(cls, fileobj, lazy_load=False):
+        ''' Loads streamlines from a file-like object.
+
+        Parameters
+        ----------
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            pointing to TRK file (and ready to read from the beginning
+            of the TRK header).
+
+        lazy_load : boolean (optional)
+            Load streamlines in a lazy manner i.e. they will not be kept
+            in memory.
+
+        Returns
+        -------
+        trk_file : ``TrkFile`` object
+            Returns an object containing tractogram data and header
+            information.
+
+        Notes
+        -----
+        Streamlines of the returned tractogram are assumed to be in RASmm
+        space where coordinate (0,0,0) refers to the center of the voxel.
+        '''
+        trk_reader = TrkReader(fileobj)
+
+        # TRK's streamlines are in 'voxelmm' space, we will compute the
+        # affine matrix that will bring them back to RAS+ and mm space.
+        affine = np.eye(4)
+
+        # The affine matrix found in the TRK header requires the points to be
+        # in the voxel space.
+        # voxelmm -> voxel
+        scale = np.eye(4)
+        scale[range(3), range(3)] /= trk_reader.header[Field.VOXEL_SIZES]
+        affine = np.dot(scale, affine)
+
+        # TrackVis considers coordinate (0,0,0) to be the corner of the voxel
+        # whereas streamlines returned assume (0,0,0) to be the center of the
+        # voxel. Thus, streamlines are shifted by half a voxel.
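+        # E.g. with voxel sizes (2, 2, 2), a TRK point at voxelmm (1, 1, 1)
+        # becomes voxel (0.5, 0.5, 0.5) after the scaling above, i.e. the
+        # center of voxel (0, 0, 0), and maps to (0, 0, 0) once shifted.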
+        offset = np.eye(4)
+        offset[:-1, -1] -= 0.5
+        affine = np.dot(offset, affine)
+
+        # If the voxel order implied by the affine does not match the voxel
+        # order in the TRK header, change the orientation.
+        # voxel (header) -> voxel (affine)
+        header_ornt = asstr(trk_reader.header[Field.VOXEL_ORDER])
+        affine_ornt = "".join(nib.orientations.aff2axcodes(trk_reader.header[Field.VOXEL_TO_RASMM]))
+        header_ornt = nib.orientations.axcodes2ornt(header_ornt)
+        affine_ornt = nib.orientations.axcodes2ornt(affine_ornt)
+        ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt)
+        M = nib.orientations.inv_ornt_aff(ornt, trk_reader.header[Field.DIMENSIONS])
+        affine = np.dot(M, affine)
+
+        # Apply the affine found in the TRK header.
+        # voxel -> rasmm
+        affine = np.dot(trk_reader.header[Field.VOXEL_TO_RASMM], affine)
+
+        # Find the scalar and property names.
+        data_per_point_slice = {}
+        if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0:
+            cpt = 0
+            for scalar_name in trk_reader.header['scalar_name']:
+                scalar_name = asstr(scalar_name)
+                if len(scalar_name) == 0:
+                    continue
+
+                # Check if we encoded the number of values we stored for this scalar name.
+                nb_scalars = 1
+                if scalar_name[-2] == '\x00' and scalar_name[-1] != '\x00':
+                    nb_scalars = int(np.fromstring(scalar_name[-1], np.int8))
+
+                scalar_name = scalar_name.split('\x00')[0]
+                data_per_point_slice[scalar_name] = slice(cpt, cpt+nb_scalars)
+                cpt += nb_scalars
+
+            if cpt < trk_reader.header[Field.NB_SCALARS_PER_POINT]:
+                data_per_point_slice['scalars'] = slice(cpt, trk_reader.header[Field.NB_SCALARS_PER_POINT])
+
+        data_per_streamline_slice = {}
+        if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0:
+            cpt = 0
+            for property_name in trk_reader.header['property_name']:
+                property_name = asstr(property_name)
+                if len(property_name) == 0:
+                    continue
+
+                # Check if we encoded the number of values we stored for this property name.
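+                # (bytes 19-20 of the slot hold '\x00' followed by the count,
+                # mirroring what TrkWriter.write does on save).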
+                nb_properties = 1
+                if property_name[-2] == '\x00' and property_name[-1] != '\x00':
+                    nb_properties = int(np.fromstring(property_name[-1], np.int8))
+
+                property_name = property_name.split('\x00')[0]
+                data_per_streamline_slice[property_name] = slice(cpt, cpt+nb_properties)
+                cpt += nb_properties
+
+            if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]:
+                data_per_streamline_slice['properties'] = slice(cpt, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE])
+
+        if lazy_load:
+            def _read():
+                for pts, scals, props in trk_reader:
+                    data_for_points = dict((k, scals[:, v]) for k, v in data_per_point_slice.items())
+                    data_for_streamline = dict((k, props[v]) for k, v in data_per_streamline_slice.items())
+                    yield TractogramItem(pts, data_for_streamline, data_for_points)
+
+            tractogram = LazyTractogram.create_from(_read)
+
+        else:
+            streamlines, scalars, properties = create_compactlist_from_generator(trk_reader)
+            tractogram = Tractogram(streamlines)
+
+            for scalar_name, slice_ in data_per_point_slice.items():
+                clist = CompactList()
+                clist._data = scalars._data[:, slice_]
+                clist._offsets = scalars._offsets
+                clist._lengths = scalars._lengths
+                tractogram.data_per_point[scalar_name] = clist
+
+            for property_name, slice_ in data_per_streamline_slice.items():
+                tractogram.data_per_streamline[property_name] = properties[:, slice_]
+
+        # Bring tractogram to RAS+ and mm space
+        tractogram.apply_affine(affine.astype(np.float32))
+
+        ## Perform some integrity checks
+        #if tractogram.header.voxel_sizes != trk_reader.header[Field.VOXEL_SIZES]:
+        #    raise HeaderError("'voxel_sizes' does not match the affine.")
+        #if tractogram.header.nb_scalars_per_point != trk_reader.header[Field.NB_SCALARS_PER_POINT]:
+        #    raise HeaderError("'nb_scalars_per_point' does not match.")
+        #if tractogram.header.nb_properties_per_streamline != trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]:
+        #    raise HeaderError("'nb_properties_per_streamline' does not match.")
+
+        return cls(tractogram, header=trk_reader.header, ref=affine)
+
+    def save(self, fileobj):
+        ''' Saves tractogram to a file-like object using TRK format.
+
+        Parameters
+        ----------
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            opened and ready to write, starting at the beginning of the
+            TRK header.
+        '''
+        trk_writer = TrkWriter(fileobj, self.header)
+        trk_writer.write(self.tractogram)
+
+    def __str__(self):
+        ''' Gets a formatted string of the header of this TRK file.
+
+        Returns
+        -------
+        info : string
+            Header information relevant to the TRK format.
+        '''
+        #trk_reader = TrkReader(fileobj)
+        hdr = self.header
+
+        info = ""
+        info += "\nMAGIC NUMBER: {0}".format(hdr[Field.MAGIC_NUMBER])
+        info += "\nv.{0}".format(hdr['version'])
+        info += "\ndim: {0}".format(hdr[Field.DIMENSIONS])
+        info += "\nvoxel_sizes: {0}".format(hdr[Field.VOXEL_SIZES])
+        info += "\norigin: {0}".format(hdr[Field.ORIGIN])
+        info += "\nnb_scalars: {0}".format(hdr[Field.NB_SCALARS_PER_POINT])
+        info += "\nscalar_name:\n {0}".format("\n".join(hdr['scalar_name']))
+        info += "\nnb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE])
+        info += "\nproperty_name:\n {0}".format("\n".join(hdr['property_name']))
+        info += "\nvox_to_world: {0}".format(hdr[Field.VOXEL_TO_RASMM])
+        info += "\nvoxel_order: {0}".format(hdr[Field.VOXEL_ORDER])
+        info += "\nimage_orientation_patient: {0}".format(hdr['image_orientation_patient'])
+        info += "\npad1: {0}".format(hdr['pad1'])
+        info += "\npad2: {0}".format(hdr['pad2'])
+        info += "\ninvert_x: {0}".format(hdr['invert_x'])
+        info += "\ninvert_y: {0}".format(hdr['invert_y'])
+        info += "\ninvert_z: {0}".format(hdr['invert_z'])
+        info += "\nswap_xy: {0}".format(hdr['swap_xy'])
+        info += "\nswap_yz: {0}".format(hdr['swap_yz'])
+        info += "\nswap_zx: {0}".format(hdr['swap_zx'])
+        info += "\nn_count: {0}".format(hdr[Field.NB_STREAMLINES])
+        info += "\nhdr_size: {0}".format(hdr['hdr_size'])
+        #info += "endianness: {0}".format(hdr[Field.ENDIAN])
+
+        return info
diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py
new file mode 100644
index 0000000000..7bbbe1ef8d
--- /dev/null
+++ b/nibabel/streamlines/utils.py
@@ -0,0 +1,35 @@
+import numpy as np
+import nibabel
+import itertools
+
+from nibabel.spatialimages import SpatialImage
+
+
+def get_affine_from_reference(ref):
+    """ Returns the affine defining the reference space.
+
+    Parameters
+    ----------
+    ref : filename | `Nifti1Image` object | 2D array (4,4)
+        Reference space where streamlines live in `fileobj`.
+
+    Returns
+    -------
+    affine : 2D array (4,4)
+    """
+    if type(ref) is np.ndarray:
+        if ref.shape != (4, 4):
+            raise ValueError("`ref` needs to be a numpy array with shape (4,4)!")
+
+        return ref
+    elif isinstance(ref, SpatialImage):
+        return ref.affine
+
+    # Assume `ref` is the name of a neuroimaging file.
+    return nibabel.load(ref).affine
+
+
+def pop(iterable):
+    """ Returns the next item from the iterable, else None. """
+    value = list(itertools.islice(iterable, 1))
+    return value[0] if len(value) > 0 else None
diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py
index 0da16744d8..8fdbc74f08 100644
--- a/nibabel/testing/__init__.py
+++ b/nibabel/testing/__init__.py
@@ -11,9 +11,12 @@
 
 import sys
 import warnings
+import collections
 from os.path import dirname, abspath, join as pjoin
+from nibabel.externals.six.moves import zip_longest
 
 import numpy as np
+from numpy.testing import assert_array_equal
 
 # Allow failed import of nose if not now running tests
 try:
@@ -57,6 +60,22 @@ def assert_allclose_safely(a, b, match_nans=True):
         assert_true(np.allclose(a, b))
 
 
+def check_iteration(iterable):
+    """ Checks that an object can be iterated through without errors.
""" + try: + for _ in iterable: + pass + except: + return False + + return True + + +def assert_arrays_equal(arrays1, arrays2): + for arr1, arr2 in zip_longest(arrays1, arrays2, fillvalue=None): + assert_array_equal(arr1, arr2) + + def get_fresh_mod(mod_name=__name__): # Get this module, with warning registry empty my_mod = sys.modules[mod_name] diff --git a/setup.py b/setup.py index db0b69866b..9ebaf7747b 100755 --- a/setup.py +++ b/setup.py @@ -89,6 +89,7 @@ def main(**extra_args): 'nibabel.testing', 'nibabel.tests', 'nibabel.benchmarks', + 'nibabel.streamlines', # install nisext as its own package 'nisext', 'nisext.tests'], @@ -103,6 +104,7 @@ def main(**extra_args): pjoin('externals', 'tests', 'data', '*'), pjoin('nicom', 'tests', 'data', '*'), pjoin('gifti', 'tests', 'data', '*'), + pjoin('streamlines', 'tests', 'data', '*'), ]}, scripts = [pjoin('bin', 'parrec2nii'), pjoin('bin', 'nib-ls'),