From e82c3eecff6c3c50d5425a3f4baa2f597c474c65 Mon Sep 17 00:00:00 2001
From: Marc-Alexandre Cote
Date: Thu, 20 Feb 2014 19:40:09 -0500
Subject: [PATCH 01/54] First draft

---
 nibabel/streamlines/__init__.py       |  24 ++
 nibabel/streamlines/base_format.py    |  39 +++
 nibabel/streamlines/header.py         |  15 +
 nibabel/streamlines/tests/__init__.py |   0
 nibabel/streamlines/tests/test_trk.py |  56 ++++
 nibabel/streamlines/trk.py            | 464 ++++++++++++++++++++++++++
 nibabel/streamlines/utils.py          |  47 +++
 7 files changed, 645 insertions(+)
 create mode 100644 nibabel/streamlines/__init__.py
 create mode 100644 nibabel/streamlines/base_format.py
 create mode 100644 nibabel/streamlines/header.py
 create mode 100644 nibabel/streamlines/tests/__init__.py
 create mode 100644 nibabel/streamlines/tests/test_trk.py
 create mode 100644 nibabel/streamlines/trk.py
 create mode 100644 nibabel/streamlines/utils.py

diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py
new file mode 100644
index 0000000000..67a9ec56ff
--- /dev/null
+++ b/nibabel/streamlines/__init__.py
@@ -0,0 +1,24 @@
+from nibabel.openers import Opener
+
+from nibabel.streamlines.utils import detect_format
+
+
+def load(fileobj):
+    ''' Load a file of streamlines, return an instance associated with the file format
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object
+        pointing to a streamlines file (and ready to read from the beginning
+        of the streamlines file's header)
+
+    Returns
+    -------
+    obj : instance of ``StreamlineFile``
+        Returns an instance of a ``StreamlineFile`` subclass corresponding to
+        the format of the streamlines file ``fileobj``.
+    '''
+    fileobj = Opener(fileobj)
+    streamlines_file = detect_format(fileobj)
+    return streamlines_file.load(fileobj)
diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py
new file mode 100644
index 0000000000..eb10baa428
--- /dev/null
+++ b/nibabel/streamlines/base_format.py
@@ -0,0 +1,39 @@
+
+class StreamlineFile:
+    @staticmethod
+    def get_magic_number():
+        raise NotImplementedError()
+
+    @classmethod
+    def is_correct_format(cls, fileobj):
+        raise NotImplementedError()
+
+    def get_header(self):
+        raise NotImplementedError()
+
+    def get_streamlines(self, as_generator=False):
+        raise NotImplementedError()
+
+    def get_scalars(self, as_generator=False):
+        raise NotImplementedError()
+
+    def get_properties(self, as_generator=False):
+        raise NotImplementedError()
+
+    @classmethod
+    def load(cls, fileobj):
+        raise NotImplementedError()
+
+    def save(self, filename):
+        raise NotImplementedError()
+
+    def __iter__(self):
+        raise NotImplementedError()
+
+
+class DynamicStreamlineFile(StreamlineFile):
+    def append(self, streamlines):
+        raise NotImplementedError()
+
+    def __iadd__(self, streamlines):
+        return self.append(streamlines)
\ No newline at end of file
diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py
new file mode 100644
index 0000000000..b12c42d8d4
--- /dev/null
+++ b/nibabel/streamlines/header.py
@@ -0,0 +1,15 @@
+class Field:
+    NB_STREAMLINES = "nb_streamlines"
+    STEP_SIZE = "step_size"
+    METHOD = "method"
+    NB_SCALARS_PER_POINT = "nb_scalars_per_point"
+    NB_PROPERTIES_PER_STREAMLINE = "nb_properties_per_streamline"
+    NB_POINTS = "nb_points"
+    VOXEL_SIZES = "voxel_sizes"
+    DIMENSIONS = "dimensions"
+    MAGIC_NUMBER = "magic_number"
+    ORIGIN = "origin"
+    VOXEL_TO_WORLD = "voxel_to_world"
+    VOXEL_ORDER = "voxel_order"
+    WORLD_ORDER = "world_order"
+    ENDIAN = "endian"
\ No newline
at end of file diff --git a/nibabel/streamlines/tests/__init__.py b/nibabel/streamlines/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py new file mode 100644 index 0000000000..0f01b03dfe --- /dev/null +++ b/nibabel/streamlines/tests/test_trk.py @@ -0,0 +1,56 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: + +from pdb import set_trace as dbg + +from os.path import join as pjoin, dirname + +from numpy.testing import (assert_equal, + assert_almost_equal, + assert_array_equal, + assert_array_almost_equal, + assert_raises) + +DATA_PATH = pjoin(dirname(__file__), 'data') + +import nibabel as nib +from nibabel.streamlines.header import Field + + +def test_load_file(): + # Test loading empty file + # empty_file = pjoin(DATA_PATH, "empty.trk") + # empty_trk = nib.streamlines.load(empty_file) + + # hdr = empty_trk.get_header() + # points = empty_trk.get_points(as_generator=False) + # scalars = empty_trk.get_scalars(as_generator=False) + # properties = empty_trk.get_properties(as_generator=False) + + # assert_equal(hdr[Field.NB_STREAMLINES], 0) + # assert_equal(len(points), 0) + # assert_equal(len(scalars), 0) + # assert_equal(len(properties), 0) + + # for i in empty_trk: pass # Check if we can iterate through the streamlines. + + # Test loading non-empty file + trk_file = pjoin(DATA_PATH, "uncinate.trk") + trk = nib.streamlines.load(trk_file) + + hdr = trk.get_header() + points = trk.get_points(as_generator=False) + 1/0 + scalars = trk.get_scalars(as_generator=False) + properties = trk.get_properties(as_generator=False) + + assert_equal(hdr[Field.NB_STREAMLINES] > 0, True) + assert_equal(len(points) > 0, True) + #assert_equal(len(scalars), 0) + #assert_equal(len(properties), 0) + + for i in trk: pass # Check if we can iterate through the streamlines. + + +if __name__ == "__main__": + test_load_file() \ No newline at end of file diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py new file mode 100644 index 0000000000..6a02397cf6 --- /dev/null +++ b/nibabel/streamlines/trk.py @@ -0,0 +1,464 @@ +# Documentation available here: +# http://www.trackvis.org/docs/?subsect=fileformat +from pdb import set_trace as dbg + +import os +import warnings +import numpy as np +from numpy.lib.recfunctions import append_fields + +from nibabel.openers import Opener +from nibabel.volumeutils import (native_code, swapped_code, endian_codes) + +from nibabel.streamlines.base_format import DynamicStreamlineFile +from nibabel.streamlines.header import Field + +# Definition of trackvis header structure. 
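+# The full 1000-byte header is declared below as a numpy structured dtype,
+# so the whole block can be read and parsed in a single call (a standalone
+# sketch of this pattern is given at the end of this section).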
+# See http://www.trackvis.org/docs/?subsect=fileformat
+# See http://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
+header_1_dtd = [
+    (Field.MAGIC_NUMBER, 'S6'),
+    (Field.DIMENSIONS, 'h', 3),
+    (Field.VOXEL_SIZES, 'f4', 3),
+    (Field.ORIGIN, 'f4', 3),
+    (Field.NB_SCALARS_PER_POINT, 'h'),
+    ('scalar_name', 'S20', 10),
+    (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'),
+    ('property_name', 'S20', 10),
+    ('reserved', 'S508'),
+    (Field.VOXEL_ORDER, 'S4'),
+    ('pad2', 'S4'),
+    ('image_orientation_patient', 'f4', 6),
+    ('pad1', 'S2'),
+    ('invert_x', 'S1'),
+    ('invert_y', 'S1'),
+    ('invert_z', 'S1'),
+    ('swap_xy', 'S1'),
+    ('swap_yz', 'S1'),
+    ('swap_zx', 'S1'),
+    (Field.NB_STREAMLINES, 'i4'),
+    ('version', 'i4'),
+    ('hdr_size', 'i4'),
+    ]
+
+# Version 2 adds a 4x4 matrix giving the affine transformation going
+# from voxel coordinates in the referenced 3D voxel matrix, to xyz
+# coordinates (axes L->R, P->A, I->S). If the (0-based) value [3, 3] from
+# this matrix is 0, this means the matrix is not recorded.
+header_2_dtd = [
+    (Field.MAGIC_NUMBER, 'S6'),
+    (Field.DIMENSIONS, 'h', 3),
+    (Field.VOXEL_SIZES, 'f4', 3),
+    (Field.ORIGIN, 'f4', 3),
+    (Field.NB_SCALARS_PER_POINT, 'h'),
+    ('scalar_name', 'S20', 10),
+    (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'),
+    ('property_name', 'S20', 10),
+    (Field.VOXEL_TO_WORLD, 'f4', (4,4)),  # new field for version 2
+    ('reserved', 'S444'),
+    (Field.VOXEL_ORDER, 'S4'),
+    ('pad2', 'S4'),
+    ('image_orientation_patient', 'f4', 6),
+    ('pad1', 'S2'),
+    ('invert_x', 'S1'),
+    ('invert_y', 'S1'),
+    ('invert_z', 'S1'),
+    ('swap_xy', 'S1'),
+    ('swap_yz', 'S1'),
+    ('swap_zx', 'S1'),
+    (Field.NB_STREAMLINES, 'i4'),
+    ('version', 'i4'),
+    ('hdr_size', 'i4'),
+    ]
+
+# Full header numpy dtypes
+header_1_dtype = np.dtype(header_1_dtd)
+header_2_dtype = np.dtype(header_2_dtd)
+
+
+class HeaderError(Exception):
+    pass
+
+
+class DataError(Exception):
+    pass
+
+
+class TrkFile(DynamicStreamlineFile):
+    MAGIC_NUMBER = "TRACK"
+    OFFSET = 1000
+
+    def __init__(self, hdr, streamlines, scalars, properties):
+        self.filename = None
+
+        self.hdr = hdr
+        self.streamlines = streamlines
+        self.scalars = scalars
+        self.properties = properties
+
+    #####
+    # Static Methods
+    ###
+    @classmethod
+    def get_magic_number(cls):
+        ''' Return TRK's magic number '''
+        return cls.MAGIC_NUMBER
+
+    @classmethod
+    def is_correct_format(cls, fileobj):
+        ''' Check if the file is in TRK format.
+
+        Parameters
+        ----------
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            pointing to a TRK file (and ready to read from the beginning
+            of the TRK header data)
+
+        Returns
+        -------
+        is_correct_format : boolean
+            Returns True if `fileobj` is in TRK format, False otherwise.
+        '''
+        with Opener(fileobj) as fileobj:
+            magic_number = fileobj.read(5)
+            fileobj.seek(-5, os.SEEK_CUR)
+            return magic_number == cls.MAGIC_NUMBER
+
+        return False
+
+    @classmethod
+    def load(cls, fileobj):
+        hdr = {}
+        pos_header = 0
+        pos_data = 0
+
+        with Opener(fileobj) as fileobj:
+            pos_header = fileobj.tell()
+
+            #####
+            # Read header
+            ###
+            hdr_str = fileobj.read(header_2_dtype.itemsize)
+            hdr = np.fromstring(string=hdr_str, dtype=header_2_dtype)
+
+            # Check endianness: 'hdr_size' is always written as 1000, so any
+            # other value means the header uses the opposite byte order.
+            # Swap while the header is still an ndarray; a plain dict
+            # cannot be byte-swapped.
+            endian = native_code
+            if hdr['hdr_size'] != 1000:
+                endian = swapped_code
+                hdr = hdr.newbyteorder()
+                if hdr['hdr_size'] != 1000:
+                    raise HeaderError('Invalid hdr_size of {0}'.format(hdr['hdr_size']))
+
+            if hdr['version'] == 1:
+                hdr = np.fromstring(string=hdr_str, dtype=header_1_dtype.newbyteorder(endian))
+            elif hdr['version'] == 2:
+                pass  # Nothing more to do here
+            else:
+                raise HeaderError('NiBabel only supports versions 1 and 2.')
+
+            # Make header a dictionary instead of an ndarray
+            hdr = dict(zip(hdr.dtype.names, hdr[0]))
+            #hdr = append_fields(hdr, Field.ENDIAN, [native_code], usemask=False)
+            hdr[Field.ENDIAN] = endian
+
+            # Add more header fields implied by trk format.
+            #hdr = append_fields(hdr, Field.WORLD_ORDER, ["RAS"], usemask=False)
+            hdr[Field.WORLD_ORDER] = "RAS"
+
+            pos_data = fileobj.tell()
+
+            i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4")
+            f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4")
+
+            nb_streamlines = 0
+
+            # Either verify the number of streamlines specified in the header is correct or
+            # count the actual number of streamlines in case it was not specified in the header.
+            while True:
+                # Read number of points of the streamline
+                buf = fileobj.read(i4_dtype.itemsize)
+
+                if buf == b'':
+                    break  # EOF
+
+                nb_pts = np.fromstring(buf,
+                                       dtype=i4_dtype,
+                                       count=1)
+
+                # These counts are in 4-byte values; converted to bytes below.
+                bytes_to_skip = nb_pts * 3  # x, y, z coordinates
+                bytes_to_skip += nb_pts * hdr[Field.NB_SCALARS_PER_POINT]
+                bytes_to_skip += hdr[Field.NB_PROPERTIES_PER_STREAMLINE]
+
+                # Seek to the next streamline in the file.
+                fileobj.seek(bytes_to_skip * f4_dtype.itemsize, os.SEEK_CUR)
+
+                nb_streamlines += 1
+
+        if hdr[Field.NB_STREAMLINES] != nb_streamlines:
+            warnings.warn(('The number of streamlines specified in the header ({0}) does not match '
+                           'the actual number of streamlines contained in this file ({1}). '
+                           'The latter will be used.').format(hdr[Field.NB_STREAMLINES],
+                                                              nb_streamlines))
+
+        hdr[Field.NB_STREAMLINES] = nb_streamlines
+
+        trk_file = cls(hdr, [], [], [])
+        trk_file.pos_header = pos_header
+        trk_file.pos_data = pos_data
+
+        return trk_file
+        # cls(hdr, streamlines, scalars, properties)
+
+    def get_header(self):
+        return self.hdr
+
+    def get_points(self, as_generator=False):
+        self.fileobj.seek(self.pos_data, os.SEEK_SET)
+        pos = self.pos_data
+
+        i4_dtype = np.dtype(self.hdr[Field.ENDIAN] + "i4")
+        f4_dtype = np.dtype(self.hdr[Field.ENDIAN] + "f4")
+
+        for i in range(self.hdr[Field.NB_STREAMLINES]):
+            # Read number of points of the streamline
+            nb_pts = np.fromstring(self.fileobj.read(i4_dtype.itemsize),
+                                   dtype=i4_dtype,
+                                   count=1)
+
+            # Read the (x, y, z) points of the streamline
+            pts = np.fromstring(self.fileobj.read(nb_pts * 3 * f4_dtype.itemsize),
+                                dtype=f4_dtype,
+                                count=nb_pts * 3).reshape((-1, 3))
+
+            pos = self.fileobj.tell()
+            yield pts
+            self.fileobj.seek(pos, os.SEEK_SET)
+
+            bytes_to_skip = nb_pts * self.hdr[Field.NB_SCALARS_PER_POINT]
+            bytes_to_skip += self.hdr[Field.NB_PROPERTIES_PER_STREAMLINE]
+
+            # Seek to the next streamline in the file.
+ self.fileobj.seek(bytes_to_skip * f4_dtype.itemsize, os.SEEK_CUR) + + ##### + # Methods + ### + + + + +# import os +# import logging +# import numpy as np + +# from tractconverter.formats.header import Header as H + + +# def readBinaryBytes(f, nbBytes, dtype): +# buff = f.read(nbBytes * dtype.itemsize) +# return np.frombuffer(buff, dtype=dtype) + + +# class TRK: +# # self.hdr +# # self.filename +# # self.hdr[H.ENDIAN] +# # self.FIBER_DELIMITER +# # self.END_DELIMITER + +# @staticmethod +# def create(filename, hdr, anatFile=None): +# f = open(filename, 'wb') +# f.write(TRK.MAGIC_NUMBER + "\n") +# f.close() + +# trk = TRK(filename, load=False) +# trk.hdr = hdr +# trk.writeHeader() + +# return trk + +# ##### +# # Methods +# ### +# def __init__(self, filename, anatFile=None, load=True): +# if not TRK._check(filename): +# raise NameError("Not a TRK file.") + +# self.filename = filename +# self.hdr = {} +# if load: +# self._load() + +# def _load(self): +# f = open(self.filename, 'rb') + +# ##### +# # Read header +# ### +# self.hdr[H.MAGIC_NUMBER] = f.read(6) +# self.hdr[H.DIMENSIONS] = np.frombuffer(f.read(6), dtype='i4') +# self.hdr["version"] = self.hdr["version"].astype('>i4') +# self.hdr["hdr_size"] = self.hdr["hdr_size"].astype('>i4') + +# nb_fibers = 0 +# self.hdr[H.NB_POINTS] = 0 + +# #Either verify the number of streamlines specified in the header is correct or +# # count the actual number of streamlines in case it is not specified in the header. +# remainingBytes = os.path.getsize(self.filename) - self.OFFSET +# while remainingBytes > 0: +# # Read points +# nbPoints = readBinaryBytes(f, 1, np.dtype(self.hdr[H.ENDIAN] + "i4"))[0] +# self.hdr[H.NB_POINTS] += nbPoints +# # This seek is used to go to the next points number indication in the file. +# f.seek((nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]) +# + self.hdr[H.NB_PROPERTIES_PER_STREAMLINE]) * 4, 1) # Relative seek +# remainingBytes -= (nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]) +# + self.hdr[H.NB_PROPERTIES_PER_STREAMLINE]) * 4 + 4 +# nb_fibers += 1 + +# if self.hdr[H.NB_FIBERS] != nb_fibers: +# logging.warn('The number of streamlines specified in header ({1}) does not match ' + +# 'the actual number of streamlines contained in this file ({1}). ' + +# 'The latter will be used.'.format(self.hdr[H.NB_FIBERS], nb_fibers)) + +# self.hdr[H.NB_FIBERS] = nb_fibers + +# f.close() + +# def writeHeader(self): +# # Get the voxel size and format it as an array. +# voxel_sizes = np.asarray(self.hdr.get(H.VOXEL_SIZES, (1.0, 1.0, 1.0)), dtype=' 0: +# # Read points +# nbPoints = readBinaryBytes(f, 1, np.dtype(self.hdr[H.ENDIAN] + "i4"))[0] +# ptsAndScalars = readBinaryBytes(f, +# nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]), +# np.dtype(self.hdr[H.ENDIAN] + "f4")) + +# newShape = [-1, 3 + self.hdr[H.NB_SCALARS_PER_POINT]] +# ptsAndScalars = ptsAndScalars.reshape(newShape) + +# pointsWithoutScalars = ptsAndScalars[:, 0:3] +# yield pointsWithoutScalars + +# # For now, we do not process the tract properties, so just skip over them. 
+# remainingBytes -= nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]) * 4 + 4 +# remainingBytes -= self.hdr[H.NB_PROPERTIES_PER_STREAMLINE] * 4 +# cpt += 1 + +# f.close() + +# def __str__(self): +# text = "" +# text += "MAGIC NUMBER: {0}".format(self.hdr[H.MAGIC_NUMBER]) +# text += "v.{0}".format(self.hdr['version']) +# text += "dim: {0}".format(self.hdr[H.DIMENSIONS]) +# text += "voxel_sizes: {0}".format(self.hdr[H.VOXEL_SIZES]) +# text += "orgin: {0}".format(self.hdr[H.ORIGIN]) +# text += "nb_scalars: {0}".format(self.hdr[H.NB_SCALARS_PER_POINT]) +# text += "scalar_name:\n {0}".format("\n".join(self.hdr['scalar_name'])) +# text += "nb_properties: {0}".format(self.hdr[H.NB_PROPERTIES_PER_STREAMLINE]) +# text += "property_name:\n {0}".format("\n".join(self.hdr['property_name'])) +# text += "vox_to_world: {0}".format(self.hdr[H.VOXEL_TO_WORLD]) +# text += "world_order: {0}".format(self.hdr[H.WORLD_ORDER]) +# text += "voxel_order: {0}".format(self.hdr[H.VOXEL_ORDER]) +# text += "image_orientation_patient: {0}".format(self.hdr['image_orientation_patient']) +# text += "pad1: {0}".format(self.hdr['pad1']) +# text += "pad2: {0}".format(self.hdr['pad2']) +# text += "invert_x: {0}".format(self.hdr['invert_x']) +# text += "invert_y: {0}".format(self.hdr['invert_y']) +# text += "invert_z: {0}".format(self.hdr['invert_z']) +# text += "swap_xy: {0}".format(self.hdr['swap_xy']) +# text += "swap_yz: {0}".format(self.hdr['swap_yz']) +# text += "swap_zx: {0}".format(self.hdr['swap_zx']) +# text += "n_count: {0}".format(self.hdr[H.NB_FIBERS]) +# text += "hdr_size: {0}".format(self.hdr['hdr_size']) +# text += "endianess: {0}".format(self.hdr[H.ENDIAN]) + +# return text diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py new file mode 100644 index 0000000000..511f27fe01 --- /dev/null +++ b/nibabel/streamlines/utils.py @@ -0,0 +1,47 @@ +import os + +import nibabel as nib + +from nibabel.streamlines.trk import TrkFile +#from nibabel.streamlines.tck import TckFile +#from nibabel.streamlines.vtk import VtkFile +#from nibabel.streamlines.fib import FibFile + +# Supported format +FORMATS = {"trk": TrkFile, + #"tck": TckFile, + #"vtk": VtkFile, + #"fib": FibFile, + } + +def is_supported(fileobj): + return detect_format(fileobj) is not None + + +def detect_format(fileobj): + for format in FORMATS.values(): + if format.is_correct_format(fileobj): + return format + + if isinstance(fileobj, basestring): + _, ext = os.path.splitext(fileobj) + return FORMATS.get(ext, None) + + return None + + +def convert(in_fileobj, out_filename): + in_fileobj = nib.streamlines.load(in_fileobj) + out_format = nib.streamlines.guess_format(out_filename) + + hdr = in_fileobj.get_header() + points = in_fileobj.get_points(as_generator=True) + scalars = in_fileobj.get_scalars(as_generator=True) + properties = in_fileobj.get_properties(as_generator=True) + + out_fileobj = out_format(hdr, points, scalars, properties) + out_fileobj.save(out_filename) + + +def change_space(streamline_file, new_point_space): + pass \ No newline at end of file From fbfb504a27b6f77598a6ae88651def7a1a8dfaa0 Mon Sep 17 00:00:00 2001 From: Marc-Alexandre Cote Date: Fri, 18 Jul 2014 22:47:31 -0400 Subject: [PATCH 02/54] A working prototype of the new streamlines API --- nibabel/__init__.py | 5 +- nibabel/benchmarks/bench_streamlines.py | 127 +++ nibabel/externals/six.py | 1 + nibabel/streamlines/__init__.py | 27 +- nibabel/streamlines/base_format.py | 186 ++++- nibabel/streamlines/header.py | 44 +- nibabel/streamlines/tests/data/complex.trk | Bin 0 
-> 1228 bytes nibabel/streamlines/tests/data/empty.trk | Bin 0 -> 1000 bytes nibabel/streamlines/tests/data/simple.trk | Bin 0 -> 1108 bytes nibabel/streamlines/tests/test_base_format.py | 117 +++ nibabel/streamlines/tests/test_trk.py | 365 +++++++-- nibabel/streamlines/tests/test_utils.py | 323 ++++++++ nibabel/streamlines/trk.py | 750 +++++++++--------- nibabel/streamlines/utils.py | 126 ++- nibabel/testing/__init__.py | 18 + setup.py | 2 + 16 files changed, 1601 insertions(+), 490 deletions(-) create mode 100644 nibabel/benchmarks/bench_streamlines.py create mode 100644 nibabel/streamlines/tests/data/complex.trk create mode 100644 nibabel/streamlines/tests/data/empty.trk create mode 100644 nibabel/streamlines/tests/data/simple.trk create mode 100644 nibabel/streamlines/tests/test_base_format.py create mode 100644 nibabel/streamlines/tests/test_utils.py diff --git a/nibabel/__init__.py b/nibabel/__init__.py index 2df9a1c534..e0180a380b 100644 --- a/nibabel/__init__.py +++ b/nibabel/__init__.py @@ -23,9 +23,9 @@ img3 = nib.load('spm_file.img') data = img1.get_data() - affine = img1.affine + affine = img1.get_affine() - print(img1) + print img1 nib.save(img1, 'my_file_copy.nii.gz') @@ -63,6 +63,7 @@ apply_orientation, aff2axcodes) from .imageclasses import class_map, ext_map from . import trackvis +from .streamlines import Streamlines from . import mriutils # be friendly on systems with ancient numpy -- no tests, but at least diff --git a/nibabel/benchmarks/bench_streamlines.py b/nibabel/benchmarks/bench_streamlines.py new file mode 100644 index 0000000000..3a2e3ab39d --- /dev/null +++ b/nibabel/benchmarks/bench_streamlines.py @@ -0,0 +1,127 @@ +""" Benchmarks for load and save of streamlines + +Run benchmarks with:: + + import nibabel as nib + nib.bench() + +If you have doctests enabled by default in nose (with a noserc file or +environment variable), and you have a numpy version <= 1.6.1, this will also run +the doctests, let's hope they pass. 
+ +Run this benchmark with: + + nosetests -s --match '(?:^|[\\b_\\.//-])[Bb]ench' /path/to/bench_streamlines.py +""" +from __future__ import division, print_function + +import os +import numpy as np + +from nibabel.externals.six import BytesIO +from nibabel.externals.six.moves import zip + +from nibabel.testing import assert_arrays_equal + +from numpy.testing import assert_array_equal +from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines import TrkFile + +import nibabel as nib +import nibabel.trackvis as tv + +from numpy.testing import measure + + +def bench_load_trk(): + NB_STREAMLINES = 1000 + NB_POINTS = 1000 + points = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)] + repeat = 20 + + trk_file = BytesIO() + trk = list(zip(points, [None]*NB_STREAMLINES, [None]*NB_STREAMLINES)) + tv.write(trk_file, trk) + + mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, lazy_load=False)', repeat) + print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) + + mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file)', repeat) + print("Old: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old)) + print("Speedup of %2f" % (mtime_old/mtime_new)) + + # Points and scalars + scalars = [np.random.rand(NB_POINTS, 10).astype('float32') for i in range(NB_STREAMLINES)] + + trk_file = BytesIO() + trk = list(zip(points, scalars, [None]*NB_STREAMLINES)) + tv.write(trk_file, trk) + + mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, lazy_load=False)', repeat) + print("New: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_new)) + + mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file)', repeat) + print("Old: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_old)) + print("Speedup of %2f" % (mtime_old/mtime_new)) + + +def bench_save_trk(): + NB_STREAMLINES = 100 + NB_POINTS = 1000 + points = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)] + repeat = 10 + + # Only points + streamlines = Streamlines(points) + trk_file_new = BytesIO() + + mtime_new = measure('trk_file_new.seek(0, os.SEEK_SET); TrkFile.save(streamlines, trk_file_new)', repeat) + print("\nNew: Saved %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) + + trk_file_old = BytesIO() + trk = list(zip(points, [None]*NB_STREAMLINES, [None]*NB_STREAMLINES)) + mtime_old = measure('trk_file_old.seek(0, os.SEEK_SET); tv.write(trk_file_old, trk)', repeat) + print("Old: Saved %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old)) + print("Speedup of %2f" % (mtime_old/mtime_new)) + + trk_file_new.seek(0, os.SEEK_SET) + trk_file_old.seek(0, os.SEEK_SET) + streams, hdr = tv.read(trk_file_old) + + for pts, A in zip(points, streams): + assert_array_equal(pts, A[0]) + + trk = nib.streamlines.load(trk_file_new, lazy_load=False) + + assert_arrays_equal(points, trk.points) + + # Points and scalars + scalars = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)] + streamlines = Streamlines(points, scalars=scalars) + trk_file_new = BytesIO() + + mtime_new = measure('trk_file_new.seek(0, os.SEEK_SET); TrkFile.save(streamlines, trk_file_new)', repeat) + print("New: Saved %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_new)) + + trk_file_old = BytesIO() + trk = list(zip(points, scalars, [None]*NB_STREAMLINES)) + mtime_old = measure('trk_file_old.seek(0, os.SEEK_SET); 
tv.write(trk_file_old, trk)', repeat)
+    print("Old: Saved %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_old))
+    print("Speedup of %2f" % (mtime_old/mtime_new))
+
+    trk_file_new.seek(0, os.SEEK_SET)
+    trk_file_old.seek(0, os.SEEK_SET)
+    streams, hdr = tv.read(trk_file_old)
+
+    for pts, scal, A in zip(points, scalars, streams):
+        assert_array_equal(pts, A[0])
+        assert_array_equal(scal, A[1])
+
+    trk = nib.streamlines.load(trk_file_new, lazy_load=False)
+
+    assert_arrays_equal(points, trk.points)
+    assert_arrays_equal(scalars, trk.scalars)
+
+
+if __name__ == '__main__':
+    bench_save_trk()
diff --git a/nibabel/externals/six.py b/nibabel/externals/six.py
index eae31454ae..b23166effd 100644
--- a/nibabel/externals/six.py
+++ b/nibabel/externals/six.py
@@ -143,6 +143,7 @@ class _MovedItems(types.ModuleType):
     MovedAttribute("StringIO", "StringIO", "io"),
     MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
     MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
+    MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"),
 
     MovedModule("builtins", "__builtin__"),
     MovedModule("configparser", "ConfigParser"),
diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py
index 67a9ec56ff..cd7cd56a21 100644
--- a/nibabel/streamlines/__init__.py
+++ b/nibabel/streamlines/__init__.py
@@ -1,24 +1,9 @@
-from nibabel.openers import Opener
-
-from nibabel.streamlines.utils import detect_format
+from nibabel.streamlines.utils import load, save
+from nibabel.streamlines.base_format import Streamlines
+from nibabel.streamlines.header import Field
 
-
-def load(fileobj):
-    ''' Load a file of streamlines, return an instance associated with the file format
-
-    Parameters
-    ----------
-    fileobj : string or file-like object
-        If string, a filename; otherwise an open file-like object
-        pointing to a streamlines file (and ready to read from the beginning
-        of the streamlines file's header)
-
-    Returns
-    -------
-    obj : instance of ``StreamlineFile``
-        Returns an instance of a ``StreamlineFile`` subclass corresponding to
-        the format of the streamlines file ``fileobj``.
-    '''
-    fileobj = Opener(fileobj)
-    streamlines_file = detect_format(fileobj)
-    return streamlines_file.load(fileobj)
+from nibabel.streamlines.trk import TrkFile
+#from nibabel.streamlines.trk import TckFile
+#from nibabel.streamlines.trk import VtkFile
diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py
index eb10baa428..43e3b8616e 100644
--- a/nibabel/streamlines/base_format.py
+++ b/nibabel/streamlines/base_format.py
@@ -1,39 +1,193 @@
-class StreamlineFile:
-    @staticmethod
-    def get_magic_number():
-        raise NotImplementedError()
+from nibabel.streamlines.header import Field
 
-    @classmethod
-    def is_correct_format(cls, fileobj):
-        raise NotImplementedError()
+from ..externals.six.moves import zip_longest
+
+
+class HeaderError(Exception):
+    pass
+
+
+class DataError(Exception):
+    pass
+
+
+class Streamlines(object):
+    ''' Class containing information about streamlines.
+
+    Streamlines objects have three main properties: ``points``, ``scalars``
+    and ``properties``. Streamlines objects can be iterated over, producing a
+    tuple of ``points``, ``scalars`` and ``properties`` for each streamline.
+
+    Parameters
+    ----------
+    points : sequence of ndarray of shape (N, 3)
+        Sequence of T streamlines. One streamline is an ndarray of shape (N, 3)
+        where N is the number of points in a streamline.
+ + scalars : sequence of ndarray of shape (N, M) + Sequence of T ndarrays of shape (N, M) where T is the number of + streamlines defined by ``points``, N is the number of points + for a particular streamline and M is the number of scalars + associated to each point (excluding the three coordinates). + + properties : sequence of ndarray of shape (P,) + Sequence of T ndarrays of shape (P,) where T is the number of + streamlines defined by ``points``, P is the number of properties + associated to each streamlines. + + hdr : dict + Header containing meta information about the streamlines. For a list + of common header's fields to use as keys see `nibabel.streamlines.Field`. + ''' + def __init__(self, points=[], scalars=[], properties=[], hdr={}): + self.hdr = hdr + + self.points = points + self.scalars = scalars + self.properties = properties + self.data = lambda: zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) + + try: + self.length = len(points) + except: + if Field.NB_STREAMLINES in hdr: + self.length = hdr[Field.NB_STREAMLINES] + else: + raise HeaderError(("Neither parameter 'points' nor 'hdr' contain information about" + " number of streamlines. Use key '{0}' to set the number of " + "streamlines in 'hdr'.").format(Field.NB_STREAMLINES)) def get_header(self): - raise NotImplementedError() + return self.hdr + + @property + def points(self): + return self._points() + + @points.setter + def points(self, value): + self._points = value if callable(value) else (lambda: value) + + @property + def scalars(self): + return self._scalars() + + @scalars.setter + def scalars(self, value): + self._scalars = value if callable(value) else lambda: value + + @property + def properties(self): + return self._properties() - def get_streamlines(self, as_generator=False): + @properties.setter + def properties(self, value): + self._properties = value if callable(value) else lambda: value + + def __iter__(self): + return self.data() + + def __len__(self): + return self.length + + +class StreamlinesFile: + ''' Convenience class to encapsulate streamlines file format. ''' + + @classmethod + def get_magic_number(cls): + ''' Return streamlines file's magic number. ''' raise NotImplementedError() - def get_scalars(self, as_generator=False): + @classmethod + def is_correct_format(cls, fileobj): + ''' Check if the file has the right streamlines file format. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header) + + Returns + ------- + is_correct_format : boolean + Returns True if `fileobj` is in the right streamlines file format. + ''' raise NotImplementedError() - def get_properties(self, as_generator=False): + @classmethod + def get_empty_header(cls): + ''' Return an empty streamlines file's header. ''' raise NotImplementedError() @classmethod - def load(cls, fileobj): + def load(cls, fileobj, lazy_load=True): + ''' Loads streamlines from a file-like object. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header) + + lazy_load : boolean + Load streamlines in a lazy manner i.e. they will not be kept + in memory. For postprocessing speed, turn off this option. 
+
+        Returns
+        -------
+        streamlines : Streamlines object
+            Returns an object containing streamlines' data and header
+            information. See 'nibabel.Streamlines'.
+        '''
         raise NotImplementedError()
 
-    def save(self, filename):
+    @classmethod
+    def save(cls, streamlines, fileobj):
+        ''' Saves streamlines to a file-like object.
+
+        Parameters
+        ----------
+        streamlines : Streamlines object
+            Object containing streamlines' data and header information.
+            See 'nibabel.Streamlines'.
+
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            opened and ready to write.
+        '''
         raise NotImplementedError()
 
-    def __iter__(self):
+    @staticmethod
+    def pretty_print(streamlines):
+        ''' Gets a formatted string containing header information
+        relevant to the streamlines file format.
+
+        Parameters
+        ----------
+        streamlines : Streamlines object
+            Object containing streamlines' data and header information.
+            See 'nibabel.Streamlines'.
+
+        Returns
+        -------
+        info : string
+            Header information relevant to the streamlines file format.
+        '''
         raise NotImplementedError()
 
 
-class DynamicStreamlineFile(StreamlineFile):
+class DynamicStreamlineFile(StreamlinesFile):
+    ''' Convenience class to encapsulate streamlines file format
+    that supports appending streamlines to an existing file.
+    '''
+
     def append(self, streamlines):
         raise NotImplementedError()
 
     def __iadd__(self, streamlines):
-        return self.append(streamlines)
\ No newline at end of file
+        return self.append(streamlines)
diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py
index b12c42d8d4..aa40d3bb97 100644
--- a/nibabel/streamlines/header.py
+++ b/nibabel/streamlines/header.py
@@ -1,4 +1,11 @@
+from nibabel.orientations import aff2axcodes
+
+
 class Field:
+    """ Header fields common to multiple streamlines file formats.
+
+    In IPython, use `nibabel.streamlines.Field??` to list them.
+    """
     NB_STREAMLINES = "nb_streamlines"
     STEP_SIZE = "step_size"
     METHOD = "method"
@@ -12,4 +19,39 @@ class Field:
     VOXEL_TO_WORLD = "voxel_to_world"
     VOXEL_ORDER = "voxel_order"
     WORLD_ORDER = "world_order"
-    ENDIAN = "endian"
\ No newline at end of file
+    ENDIAN = "endian"
+
+
+def create_header_from_nifti(img):
+    ''' Creates a common streamlines' header using a NIfTI image.
+
+    Based on the information of the NIfTI image, a dictionary is created
+    containing the following keys: `Field.ORIGIN`, `Field.DIMENSIONS`,
+    `Field.VOXEL_SIZES`, `Field.VOXEL_TO_WORLD`, `Field.WORLD_ORDER`
+    and `Field.VOXEL_ORDER`.
+
+    Parameters
+    ----------
+    img : Nifti1Image object
+        Image containing information about the anatomy where streamlines
+        were created.
+
+    Returns
+    -------
+    hdr : dict
+        Header containing meta information about streamlines extracted
+        from the anatomy.
+ ''' + img_header = img.get_header() + affine = img_header.get_best_affine() + + hdr = {} + + hdr[Field.ORIGIN] = affine[:3, -1] + hdr[Field.DIMENSIONS] = img_header.get_data_shape()[:3] + hdr[Field.VOXEL_SIZES] = img_header.get_zooms()[:3] + hdr[Field.VOXEL_TO_WORLD] = affine + hdr[Field.WORLD_ORDER] = "RAS" # Nifti space + hdr[Field.VOXEL_ORDER] = "".join(aff2axcodes(affine)) + + return hdr diff --git a/nibabel/streamlines/tests/data/complex.trk b/nibabel/streamlines/tests/data/complex.trk new file mode 100644 index 0000000000000000000000000000000000000000..2a96a7bd35f6ce69f05907a351466ac838de1045 GIT binary patch literal 1228 zcmWFua&-1)U<5-3h6Z~CW`F}`IBTdgn1;FsklYW7D|A4K7`j=b{GlEKpnMBT@h^Z3 zM!09dX~Y4N&mn@(Dj03LHZ8>)ja@W21g)<0+6@>kgov52590AKz;xaC!mPia=8QD;O77UQAr@i literal 0 HcmV?d00001 diff --git a/nibabel/streamlines/tests/data/empty.trk b/nibabel/streamlines/tests/data/empty.trk new file mode 100644 index 0000000000000000000000000000000000000000..023b3c5905f222fbd76ce9f82ca72be319b427f6 GIT binary patch literal 1000 ycmWFua&-1)Sio=sh#43f>=78q9R-6p1VC|x4k!^rH*1tX972Ez=!qB13=9BLi3y|t literal 0 HcmV?d00001 diff --git a/nibabel/streamlines/tests/data/simple.trk b/nibabel/streamlines/tests/data/simple.trk new file mode 100644 index 0000000000000000000000000000000000000000..dc5eff4adc9cd919e7bfe31a37485ecc51348109 GIT binary patch literal 1108 zcmWFua&-1)U<5-3h6Z~CW*7y7Is`y*g$^hYLpN)bKh#5j8R!8fAbtU4Fv2|pP9qK= zaR`9$85kTKfO#K?7dWs&Wguk%15gYh$G~s^$bSID42}#80zj+)#0Eg@0K@@6oZtum D?XVOc literal 0 HcmV?d00001 diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py new file mode 100644 index 0000000000..25956bcfed --- /dev/null +++ b/nibabel/streamlines/tests/test_base_format.py @@ -0,0 +1,117 @@ +import os +import unittest +import numpy as np + +from nibabel.testing import assert_arrays_equal +from nose.tools import assert_equal, assert_raises + +from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines.base_format import HeaderError +from nibabel.streamlines.header import Field + +DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') + + +class TestStreamlines(unittest.TestCase): + + def setUp(self): + self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") + self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") + self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") + + self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11, 2.22], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + + def test_streamlines_creation_from_arrays(self): + # Empty + streamlines = Streamlines() + assert_equal(len(streamlines), 0) + + # TODO: Should Streamlines have a default header? It could have + # NB_STREAMLINES, NB_SCALARS_PER_POINT and NB_PROPERTIES_PER_STREAMLINE + # already set. + hdr = streamlines.get_header() + assert_equal(len(hdr), 0) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in streamlines: + pass + + # Only points + streamlines = Streamlines(points=self.points) + assert_equal(len(streamlines), len(self.points)) + assert_arrays_equal(streamlines.points, self.points) + + # Check if we can iterate through the streamlines. 
+ for point, scalar, prop in streamlines: + pass + + # Only scalars + streamlines = Streamlines(scalars=self.colors) + assert_equal(len(streamlines), 0) + assert_arrays_equal(streamlines.scalars, self.colors) + + # TODO: is it a faulty behavior? + assert_equal(len(list(streamlines)), len(self.colors)) + + # Points, scalars and properties + streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) + assert_equal(len(streamlines), len(self.points)) + assert_arrays_equal(streamlines.points, self.points) + assert_arrays_equal(streamlines.scalars, self.colors) + assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in streamlines: + pass + + def test_streamlines_creation_from_generators(self): + # Points, scalars and properties + points = (x for x in self.points) + scalars = (x for x in self.colors) + properties = (x for x in self.mean_curvature_torsion) + + assert_raises(HeaderError, Streamlines, points, scalars, properties) + + hdr = {Field.NB_STREAMLINES: len(self.points)} + streamlines = Streamlines(points, scalars, properties, hdr) + assert_equal(len(streamlines), len(self.points)) + assert_arrays_equal(streamlines.points, self.points) + assert_arrays_equal(streamlines.scalars, self.colors) + assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + + # Have been consumed + assert_equal(len(list(streamlines)), 0) + assert_equal(len(list(streamlines.points)), 0) + assert_equal(len(list(streamlines.scalars)), 0) + assert_equal(len(list(streamlines.properties)), 0) + + def test_streamlines_creation_from_functions(self): + # Points, scalars and properties + points = lambda: (x for x in self.points) + scalars = lambda: (x for x in self.colors) + properties = lambda: (x for x in self.mean_curvature_torsion) + + assert_raises(HeaderError, Streamlines, points, scalars, properties) + + hdr = {Field.NB_STREAMLINES: len(self.points)} + streamlines = Streamlines(points, scalars, properties, hdr) + assert_equal(len(streamlines), len(self.points)) + assert_arrays_equal(streamlines.points, self.points) + assert_arrays_equal(streamlines.scalars, self.colors) + assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + + # Have been consumed but lambda functions get re-called. 
+ assert_equal(len(list(streamlines)), len(self.points)) + assert_arrays_equal(streamlines.points, self.points) + assert_arrays_equal(streamlines.scalars, self.colors) + assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 0f01b03dfe..251be7a61d 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -1,56 +1,323 @@ -# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- -# vi: set ft=python sts=4 ts=4 sw=4 et: +import os +import unittest +import numpy as np -from pdb import set_trace as dbg +from nibabel.externals.six import BytesIO -from os.path import join as pjoin, dirname - -from numpy.testing import (assert_equal, - assert_almost_equal, - assert_array_equal, - assert_array_almost_equal, - assert_raises) - -DATA_PATH = pjoin(dirname(__file__), 'data') +from nibabel.testing import assert_arrays_equal +from nose.tools import assert_equal, assert_raises import nibabel as nib +from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines.base_format import DataError, HeaderError from nibabel.streamlines.header import Field +from nibabel.streamlines.trk import TrkFile + +DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') + + +class TestTRK(unittest.TestCase): + + def setUp(self): + self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") + self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") + self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") + + self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11, 2.22], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + + def test_load_empty_file(self): + empty_trk = nib.streamlines.load(self.empty_trk_filename, lazy_load=False) + + hdr = empty_trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], 0) + assert_equal(len(empty_trk), 0) + + points = empty_trk.points + assert_equal(len(points), 0) + + scalars = empty_trk.scalars + assert_equal(len(scalars), 0) + + properties = empty_trk.properties + assert_equal(len(properties), 0) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in empty_trk: + pass + + def test_load_simple_file(self): + simple_trk = nib.streamlines.load(self.simple_trk_filename, lazy_load=False) + + hdr = simple_trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(simple_trk), len(self.points)) + + points = simple_trk.points + assert_arrays_equal(points, self.points) + + scalars = simple_trk.scalars + assert_equal(len(scalars), len(self.points)) + + properties = simple_trk.properties + assert_equal(len(properties), len(self.points)) + + # Check if we can iterate through the streamlines. 
+ for point, scalar, prop in simple_trk: + pass + + # Test lazy_load + simple_trk = nib.streamlines.load(self.simple_trk_filename, lazy_load=True) + + hdr = simple_trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(simple_trk), len(self.points)) + + points = simple_trk.points + assert_arrays_equal(points, self.points) + + scalars = simple_trk.scalars + assert_equal(len(list(scalars)), len(self.points)) + + properties = simple_trk.properties + assert_equal(len(list(properties)), len(self.points)) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in simple_trk: + pass + + def test_load_complex_file(self): + complex_trk = nib.streamlines.load(self.complex_trk_filename, lazy_load=False) + + hdr = complex_trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(complex_trk), len(self.points)) + + points = complex_trk.points + assert_arrays_equal(points, self.points) + + scalars = complex_trk.scalars + assert_arrays_equal(scalars, self.colors) + + properties = complex_trk.properties + assert_arrays_equal(properties, self.mean_curvature_torsion) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in complex_trk: + pass + + complex_trk = nib.streamlines.load(self.complex_trk_filename, lazy_load=True) + + hdr = complex_trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(complex_trk), len(self.points)) + + points = complex_trk.points + assert_arrays_equal(points, self.points) + + scalars = complex_trk.scalars + assert_arrays_equal(scalars, self.colors) + + properties = complex_trk.properties + assert_arrays_equal(properties, self.mean_curvature_torsion) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in complex_trk: + pass + + def test_write_simple_file(self): + streamlines = Streamlines(self.points) + + simple_trk_file = BytesIO() + TrkFile.save(streamlines, simple_trk_file) + + simple_trk_file.seek(0, os.SEEK_SET) + + simple_trk = nib.streamlines.load(simple_trk_file, lazy_load=False) + + hdr = simple_trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(simple_trk), len(self.points)) + + points = simple_trk.points + assert_arrays_equal(points, self.points) + + scalars = simple_trk.scalars + assert_equal(len(scalars), len(self.points)) + + properties = simple_trk.properties + assert_equal(len(properties), len(self.points)) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in simple_trk: + pass + + def test_write_complex_file(self): + # With scalars + streamlines = Streamlines(self.points, scalars=self.colors) + + trk_file = BytesIO() + TrkFile.save(streamlines, trk_file) + + trk_file.seek(0, os.SEEK_SET) + + trk = nib.streamlines.load(trk_file, lazy_load=False) + + hdr = trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(trk), len(self.points)) + + points = trk.points + assert_arrays_equal(points, self.points) + + scalars = trk.scalars + assert_arrays_equal(scalars, self.colors) + + properties = trk.properties + assert_equal(len(properties), len(self.points)) + + # Check if we can iterate through the streamlines. 
+ for point, scalar, prop in trk: + pass + + # With properties + streamlines = Streamlines(self.points, properties=self.mean_curvature_torsion) + + trk_file = BytesIO() + TrkFile.save(streamlines, trk_file) + + trk_file.seek(0, os.SEEK_SET) + + trk = nib.streamlines.load(trk_file, lazy_load=False) + + hdr = trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(trk), len(self.points)) + + points = trk.points + assert_arrays_equal(points, self.points) + + scalars = trk.scalars + assert_equal(len(scalars), len(self.points)) + + properties = trk.properties + assert_arrays_equal(properties, self.mean_curvature_torsion) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in trk: + pass + + # With scalars and properties + streamlines = Streamlines(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) + + trk_file = BytesIO() + TrkFile.save(streamlines, trk_file) + + trk_file.seek(0, os.SEEK_SET) + + trk = nib.streamlines.load(trk_file, lazy_load=False) + + hdr = trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(trk), len(self.points)) + + points = trk.points + assert_arrays_equal(points, self.points) + + scalars = trk.scalars + assert_arrays_equal(scalars, self.colors) + + properties = trk.properties + assert_arrays_equal(properties, self.mean_curvature_torsion) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in trk: + pass + + def test_write_erroneous_file(self): + # No scalars for every points + scalars = [[(1, 0, 0)], + [(0, 1, 0)], + [(0, 0, 1)]] + + streamlines = Streamlines(self.points, scalars) + assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + + # No scalars for every streamlines + scalars = [[(1, 0, 0)]*1, + [(0, 1, 0)]*2] + + streamlines = Streamlines(self.points, scalars) + assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + + # Inconsistent number of scalars between points + scalars = [[(1, 0, 0)]*1, + [(0, 1, 0), (0, 1)], + [(0, 0, 1)]*5] + + streamlines = Streamlines(self.points, scalars) + assert_raises(ValueError, TrkFile.save, streamlines, BytesIO()) + + # Inconsistent number of scalars between streamlines + scalars = [[(1, 0, 0)]*1, + [(0, 1)]*2, + [(0, 0, 1)]*5] + + streamlines = Streamlines(self.points, scalars) + assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + + # Inconsistent number of properties + properties = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + streamlines = Streamlines(self.points, properties=properties) + assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + + # No properties for every streamlines + properties = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11, 2.22], dtype="f4")] + streamlines = Streamlines(self.points, properties=properties) + assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + + def test_write_file_from_generator(self): + gen_points = (point for point in self.points) + gen_scalars = (scalar for scalar in self.colors) + gen_properties = (prop for prop in self.mean_curvature_torsion) + + assert_raises(HeaderError, Streamlines, points=gen_points, scalars=gen_scalars, properties=gen_properties) + + hdr = {Field.NB_STREAMLINES: len(self.points)} + streamlines = Streamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties, hdr=hdr) + + trk_file = BytesIO() + TrkFile.save(streamlines, trk_file) + + trk_file.seek(0, os.SEEK_SET) + 
+ trk = nib.streamlines.load(trk_file, lazy_load=False) + + hdr = trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(trk), len(self.points)) + + points = trk.points + assert_arrays_equal(points, self.points) + + scalars = trk.scalars + assert_arrays_equal(scalars, self.colors) + properties = trk.properties + assert_arrays_equal(properties, self.mean_curvature_torsion) -def test_load_file(): - # Test loading empty file - # empty_file = pjoin(DATA_PATH, "empty.trk") - # empty_trk = nib.streamlines.load(empty_file) - - # hdr = empty_trk.get_header() - # points = empty_trk.get_points(as_generator=False) - # scalars = empty_trk.get_scalars(as_generator=False) - # properties = empty_trk.get_properties(as_generator=False) - - # assert_equal(hdr[Field.NB_STREAMLINES], 0) - # assert_equal(len(points), 0) - # assert_equal(len(scalars), 0) - # assert_equal(len(properties), 0) - - # for i in empty_trk: pass # Check if we can iterate through the streamlines. - - # Test loading non-empty file - trk_file = pjoin(DATA_PATH, "uncinate.trk") - trk = nib.streamlines.load(trk_file) - - hdr = trk.get_header() - points = trk.get_points(as_generator=False) - 1/0 - scalars = trk.get_scalars(as_generator=False) - properties = trk.get_properties(as_generator=False) - - assert_equal(hdr[Field.NB_STREAMLINES] > 0, True) - assert_equal(len(points) > 0, True) - #assert_equal(len(scalars), 0) - #assert_equal(len(properties), 0) - - for i in trk: pass # Check if we can iterate through the streamlines. - - -if __name__ == "__main__": - test_load_file() \ No newline at end of file + # Check if we can iterate through the streamlines. + for point, scalar, prop in trk: + pass diff --git a/nibabel/streamlines/tests/test_utils.py b/nibabel/streamlines/tests/test_utils.py new file mode 100644 index 0000000000..a0dbcd8b61 --- /dev/null +++ b/nibabel/streamlines/tests/test_utils.py @@ -0,0 +1,323 @@ +import os +import unittest +import tempfile +import numpy as np + +from os.path import join as pjoin + +from nibabel.externals.six import BytesIO + +from nibabel.testing import assert_arrays_equal +from nose.tools import assert_equal, assert_raises, assert_true, assert_false + +import nibabel.streamlines.utils as streamline_utils + +from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines.base_format import HeaderError +from nibabel.streamlines.header import Field + +DATA_PATH = pjoin(os.path.dirname(__file__), 'data') + + +def test_is_supported(): + # Emtpy file/string + f = BytesIO() + assert_false(streamline_utils.is_supported(f)) + assert_false(streamline_utils.is_supported("")) + + # Valid file without extension + for streamlines_file in streamline_utils.FORMATS.values(): + f = BytesIO() + f.write(streamlines_file.get_magic_number()) + f.seek(0, os.SEEK_SET) + assert_true(streamline_utils.is_supported(f)) + + # Wrong extension but right magic number + for streamlines_file in streamline_utils.FORMATS.values(): + with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f: + f.write(streamlines_file.get_magic_number()) + f.seek(0, os.SEEK_SET) + assert_true(streamline_utils.is_supported(f)) + + # Good extension but wrong magic number + for ext, streamlines_file in streamline_utils.FORMATS.items(): + with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f: + f.write(b"pass") + f.seek(0, os.SEEK_SET) + assert_false(streamline_utils.is_supported(f)) + + # Wrong extension, string only + f = "my_streamlines.asd" + assert_false(streamline_utils.is_supported(f)) + + # 
Good extension, string only + for ext, streamlines_file in streamline_utils.FORMATS.items(): + f = "my_streamlines" + ext + assert_true(streamline_utils.is_supported(f)) + + +def test_detect_format(): + # Emtpy file/string + f = BytesIO() + assert_equal(streamline_utils.detect_format(f), None) + assert_equal(streamline_utils.detect_format(""), None) + + # Valid file without extension + for streamlines_file in streamline_utils.FORMATS.values(): + f = BytesIO() + f.write(streamlines_file.get_magic_number()) + f.seek(0, os.SEEK_SET) + assert_equal(streamline_utils.detect_format(f), streamlines_file) + + # Wrong extension but right magic number + for streamlines_file in streamline_utils.FORMATS.values(): + with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f: + f.write(streamlines_file.get_magic_number()) + f.seek(0, os.SEEK_SET) + assert_equal(streamline_utils.detect_format(f), streamlines_file) + + # Good extension but wrong magic number + for ext, streamlines_file in streamline_utils.FORMATS.items(): + with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f: + f.write(b"pass") + f.seek(0, os.SEEK_SET) + assert_equal(streamline_utils.detect_format(f), None) + + # Wrong extension, string only + f = "my_streamlines.asd" + assert_equal(streamline_utils.detect_format(f), None) + + # Good extension, string only + for ext, streamlines_file in streamline_utils.FORMATS.items(): + f = "my_streamlines" + ext + assert_equal(streamline_utils.detect_format(f), streamlines_file) + + +class TestLoadSave(unittest.TestCase): + # Testing scalars and properties depend on the format. + # See unit tests in the specific format test file. + + def setUp(self): + self.empty_filenames = [pjoin(DATA_PATH, "empty" + ext) for ext in streamline_utils.FORMATS.keys()] + self.simple_filenames = [pjoin(DATA_PATH, "simple" + ext) for ext in streamline_utils.FORMATS.keys()] + self.complex_filenames = [pjoin(DATA_PATH, "complex" + ext) for ext in streamline_utils.FORMATS.keys()] + + self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11, 2.22], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + + def test_load_empty_file(self): + for empty_filename in self.empty_filenames: + empty_streamlines = streamline_utils.load(empty_filename, lazy_load=False) + + hdr = empty_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], 0) + assert_equal(len(empty_streamlines), 0) + + points = empty_streamlines.points + assert_equal(len(points), 0) + + # For an empty file, scalars should be zero regardless of the format. + scalars = empty_streamlines.scalars + assert_equal(len(scalars), 0) + + # For an empty file, properties should be zero regardless of the format. + properties = empty_streamlines.properties + assert_equal(len(properties), 0) + + # Check if we can iterate through the streamlines. 
+ for point, scalar, prop in empty_streamlines: + pass + + def test_load_simple_file(self): + for simple_filename in self.simple_filenames: + simple_streamlines = streamline_utils.load(simple_filename, lazy_load=False) + + hdr = simple_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(simple_streamlines), len(self.points)) + + points = simple_streamlines.points + assert_arrays_equal(points, self.points) + + # Testing scalars and properties depend on the format. + # See unit tests in the specific format test file. + + # Check if we can iterate through the streamlines. + for point, scalar, prop in simple_streamlines: + pass + + # Test lazy_load + simple_streamlines = streamline_utils.load(simple_filename, lazy_load=True) + + hdr = simple_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(simple_streamlines), len(self.points)) + + points = simple_streamlines.points + assert_arrays_equal(points, self.points) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in simple_streamlines: + pass + + def test_load_complex_file(self): + for complex_filename in self.complex_filenames: + complex_streamlines = streamline_utils.load(complex_filename, lazy_load=False) + + hdr = complex_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(complex_streamlines), len(self.points)) + + points = complex_streamlines.points + assert_arrays_equal(points, self.points) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in complex_streamlines: + pass + + complex_streamlines = streamline_utils.load(complex_filename, lazy_load=True) + + hdr = complex_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(complex_streamlines), len(self.points)) + + points = complex_streamlines.points + assert_arrays_equal(points, self.points) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in complex_streamlines: + pass + + def test_save_simple_file(self): + for ext in streamline_utils.FORMATS.keys(): + with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: + streamlines = Streamlines(self.points) + + streamline_utils.save(streamlines, f.name) + simple_streamlines = streamline_utils.load(f, lazy_load=False) + + hdr = simple_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(simple_streamlines), len(self.points)) + + points = simple_streamlines.points + assert_arrays_equal(points, self.points) + + def test_save_complex_file(self): + for ext in streamline_utils.FORMATS.keys(): + # With scalars + with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: + streamlines = Streamlines(self.points, scalars=self.colors) + + streamline_utils.save(streamlines, f.name) + complex_streamlines = streamline_utils.load(f, lazy_load=False) + + hdr = complex_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(complex_streamlines), len(self.points)) + + points = complex_streamlines.points + assert_arrays_equal(points, self.points) + + # Check if we can iterate through the streamlines. 
+ for point, scalar, prop in complex_streamlines: + pass + + # With properties + with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: + streamlines = Streamlines(self.points, properties=self.mean_curvature_torsion) + + streamline_utils.save(streamlines, f.name) + complex_streamlines = streamline_utils.load(f, lazy_load=False) + + hdr = complex_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(complex_streamlines), len(self.points)) + + points = complex_streamlines.points + assert_arrays_equal(points, self.points) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in complex_streamlines: + pass + + # With scalars and properties + with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: + streamlines = Streamlines(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) + + streamline_utils.save(streamlines, f.name) + complex_streamlines = streamline_utils.load(f, lazy_load=False) + + hdr = complex_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(complex_streamlines), len(self.points)) + + points = complex_streamlines.points + assert_arrays_equal(points, self.points) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in complex_streamlines: + pass + + def test_save_file_from_generator(self): + for ext in streamline_utils.FORMATS.keys(): + # With scalars + with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: + gen_points = (point for point in self.points) + gen_scalars = (scalar for scalar in self.colors) + gen_properties = (prop for prop in self.mean_curvature_torsion) + + assert_raises(HeaderError, Streamlines, points=gen_points, scalars=gen_scalars, properties=gen_properties) + + hdr = {Field.NB_STREAMLINES: len(self.points)} + streamlines = Streamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties, hdr=hdr) + + streamline_utils.save(streamlines, f.name) + streamlines_loaded = streamline_utils.load(f, lazy_load=False) + + hdr = streamlines_loaded.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(streamlines_loaded), len(self.points)) + + points = streamlines_loaded.points + assert_arrays_equal(points, self.points) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in streamlines_loaded: + pass + + def test_save_file_from_function(self): + for ext in streamline_utils.FORMATS.keys(): + # With scalars + with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: + gen_points = lambda: (point for point in self.points) + gen_scalars = lambda: (scalar for scalar in self.colors) + gen_properties = lambda: (prop for prop in self.mean_curvature_torsion) + + assert_raises(HeaderError, Streamlines, points=gen_points, scalars=gen_scalars, properties=gen_properties) + + hdr = {Field.NB_STREAMLINES: len(self.points)} + streamlines = Streamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties, hdr=hdr) + + streamline_utils.save(streamlines, f.name) + streamlines_loaded = streamline_utils.load(f, lazy_load=False) + + hdr = streamlines_loaded.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(streamlines_loaded), len(self.points)) + + points = streamlines_loaded.points + assert_arrays_equal(points, self.points) + + # Check if we can iterate through the streamlines. 
+ for point, scalar, prop in streamlines_loaded: + pass diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 6a02397cf6..c256bcc0ec 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -1,107 +1,92 @@ +from __future__ import division + # Documentation available here: # http://www.trackvis.org/docs/?subsect=fileformat -from pdb import set_trace as dbg +from ..externals.six.moves import xrange +import struct import os import warnings + import numpy as np -from numpy.lib.recfunctions import append_fields from nibabel.openers import Opener -from nibabel.volumeutils import (native_code, swapped_code, endian_codes) +from nibabel.volumeutils import (native_code, swapped_code) -from nibabel.streamlines.base_format import DynamicStreamlineFile +from nibabel.streamlines.base_format import Streamlines, StreamlinesFile from nibabel.streamlines.header import Field +from nibabel.streamlines.base_format import DataError, HeaderError # Definition of trackvis header structure. # See http://www.trackvis.org/docs/?subsect=fileformat # See http://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html -header_1_dtd = [ - (Field.MAGIC_NUMBER, 'S6'), - (Field.DIMENSIONS, 'h', 3), - (Field.VOXEL_SIZES, 'f4', 3), - (Field.ORIGIN, 'f4', 3), - (Field.NB_SCALARS_PER_POINT, 'h'), - ('scalar_name', 'S20', 10), - (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), - ('property_name', 'S20', 10), - ('reserved', 'S508'), - (Field.VOXEL_ORDER, 'S4'), - ('pad2', 'S4'), - ('image_orientation_patient', 'f4', 6), - ('pad1', 'S2'), - ('invert_x', 'S1'), - ('invert_y', 'S1'), - ('invert_z', 'S1'), - ('swap_xy', 'S1'), - ('swap_yz', 'S1'), - ('swap_zx', 'S1'), - (Field.NB_STREAMLINES, 'i4'), - ('version', 'i4'), - ('hdr_size', 'i4'), - ] +header_1_dtd = [(Field.MAGIC_NUMBER, 'S6'), + (Field.DIMENSIONS, 'h', 3), + (Field.VOXEL_SIZES, 'f4', 3), + (Field.ORIGIN, 'f4', 3), + (Field.NB_SCALARS_PER_POINT, 'h'), + ('scalar_name', 'S20', 10), + (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), + ('property_name', 'S20', 10), + ('reserved', 'S508'), + (Field.VOXEL_ORDER, 'S4'), + ('pad2', 'S4'), + ('image_orientation_patient', 'f4', 6), + ('pad1', 'S2'), + ('invert_x', 'S1'), + ('invert_y', 'S1'), + ('invert_z', 'S1'), + ('swap_xy', 'S1'), + ('swap_yz', 'S1'), + ('swap_zx', 'S1'), + (Field.NB_STREAMLINES, 'i4'), + ('version', 'i4'), + ('hdr_size', 'i4'), + ] # Version 2 adds a 4x4 matrix giving the affine transformtation going # from voxel coordinates in the referenced 3D voxel matrix, to xyz # coordinates (axes L->R, P->A, I->S). IF (0 based) value [3, 3] from # this matrix is 0, this means the matrix is not recorded. 
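# A minimal sketch (not part of the patch itself) of how structured dtypes
# like `header_1_dtd`/`header_2_dtd` below map the fixed 1000-byte TRK
# header, and of how the `hdr_size` field doubles as an endianness probe:
# it is always written as 1000, so reading any other value means the header
# was produced with the opposite byte order. `header_2_dtype` refers to the
# dtype built from `header_2_dtd` later in this module; `np.frombuffer`
# stands in for the `np.fromstring` calls used in the patch.
#
#     import numpy as np
#
#     def peek_trk_header(path):
#         with open(path, 'rb') as f:
#             raw = f.read(header_2_dtype.itemsize)  # exactly 1000 bytes
#         rec = np.frombuffer(raw, dtype=header_2_dtype)[0]
#         if rec['hdr_size'] != 1000:
#             rec = rec.newbyteorder()  # reinterpret with swapped byte order
#         return dict(zip(rec.dtype.names, rec))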
-header_2_dtd = [ - (Field.MAGIC_NUMBER, 'S6'), - (Field.DIMENSIONS, 'h', 3), - (Field.VOXEL_SIZES, 'f4', 3), - (Field.ORIGIN, 'f4', 3), - (Field.NB_SCALARS_PER_POINT, 'h'), - ('scalar_name', 'S20', 10), - (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), - ('property_name', 'S20', 10), - (Field.VOXEL_TO_WORLD, 'f4', (4,4)), # new field for version 2 - ('reserved', 'S444'), - (Field.VOXEL_ORDER, 'S4'), - ('pad2', 'S4'), - ('image_orientation_patient', 'f4', 6), - ('pad1', 'S2'), - ('invert_x', 'S1'), - ('invert_y', 'S1'), - ('invert_z', 'S1'), - ('swap_xy', 'S1'), - ('swap_yz', 'S1'), - ('swap_zx', 'S1'), - (Field.NB_STREAMLINES, 'i4'), - ('version', 'i4'), - ('hdr_size', 'i4'), - ] +header_2_dtd = [(Field.MAGIC_NUMBER, 'S6'), + (Field.DIMENSIONS, 'h', 3), + (Field.VOXEL_SIZES, 'f4', 3), + (Field.ORIGIN, 'f4', 3), + (Field.NB_SCALARS_PER_POINT, 'h'), + ('scalar_name', 'S20', 10), + (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), + ('property_name', 'S20', 10), + (Field.VOXEL_TO_WORLD, 'f4', (4, 4)), # new field for version 2 + ('reserved', 'S444'), + (Field.VOXEL_ORDER, 'S4'), + ('pad2', 'S4'), + ('image_orientation_patient', 'f4', 6), + ('pad1', 'S2'), + ('invert_x', 'S1'), + ('invert_y', 'S1'), + ('invert_z', 'S1'), + ('swap_xy', 'S1'), + ('swap_yz', 'S1'), + ('swap_zx', 'S1'), + (Field.NB_STREAMLINES, 'i4'), + ('version', 'i4'), + ('hdr_size', 'i4'), + ] # Full header numpy dtypes header_1_dtype = np.dtype(header_1_dtd) header_2_dtype = np.dtype(header_2_dtd) -class HeaderError(Exception): - pass - - -class DataError(Exception): - pass - - -class TrkFile(DynamicStreamlineFile): - MAGIC_NUMBER = "TRACK" - OFFSET = 1000 +class TrkFile(StreamlinesFile): + ''' Convenience class to encapsulate TRK format. ''' - def __init__(self, hdr, streamlines, scalars, properties): - self.filename = None - - self.hdr = hdr - self.streamlines = streamlines - self.scalars = scalars - self.properties = properties + MAGIC_NUMBER = b"TRACK" + HEADER_SIZE = 1000 - ##### - # Static Methods - ### @classmethod def get_magic_number(cls): - ''' Return TRK's magic number ''' + ''' Return TRK's magic number. ''' return cls.MAGIC_NUMBER @classmethod @@ -111,354 +96,355 @@ def is_correct_format(cls, fileobj): Parameters ---------- fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to TRK file (and ready to read from the beginning - of the TRK header data) + If string, a filename; otherwise an open file-like object + pointing to TRK file (and ready to read from the beginning + of the TRK header data) Returns ------- is_correct_format : boolean - Returns True if `fileobj` is in TRK format, False otherwise. + Returns True if `fileobj` is in TRK format. ''' - with Opener(fileobj) as fileobj: - magic_number = fileobj.read(5) - fileobj.seek(-5, os.SEEK_CUR) + with Opener(fileobj) as f: + magic_number = f.read(5) + f.seek(-5, os.SEEK_CUR) return magic_number == cls.MAGIC_NUMBER return False @classmethod - def load(cls, fileobj): - hdr = {} - pos_header = 0 - pos_data = 0 + def sanity_check(cls, fileobj): + ''' Check if data is consistent with information contained in the header. 
+        [Might be useful]
-        with Opener(fileobj) as fileobj:
-            pos_header = fileobj.tell()
+        Parameters
+        ----------
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            pointing to TRK file (and ready to read from the beginning
+            of the TRK header data)
+
+        Returns
+        -------
+        is_consistent : boolean
+            Returns True if data is consistent with header, False otherwise.
+        '''
+        is_consistent = True
+
+        with Opener(fileobj) as f:
+            start_position = f.tell()
-            #####
             # Read header
-            ###
-            hdr_str = fileobj.read(header_2_dtype.itemsize)
-            hdr = np.fromstring(string=hdr_str, dtype=header_2_dtype)
-
-            if hdr['version'] == 1:
-                hdr = np.fromstring(string=hdr_str, dtype=header_1_dtype)
-            elif hdr['version'] == 2:
-                pass # Nothing more to do here
+            hdr_str = f.read(header_2_dtype.itemsize)
+            hdr_rec = np.fromstring(string=hdr_str, dtype=header_2_dtype)
+
+            if hdr_rec['version'] == 1:
+                hdr_rec = np.fromstring(string=hdr_str, dtype=header_1_dtype)
+            elif hdr_rec['version'] == 2:
+                pass  # Nothing more to do here
             else:
-                raise HeaderError('NiBabel only supports versions 1 and 2.')
+                warnings.warn("NiBabel only supports versions 1 and 2 (not v.{0}).".format(hdr_rec['version']))
+                f.seek(start_position, os.SEEK_SET)  # Restore the file position.
+                return False
 
-            # Make header a dictionnary instead of ndarray
-            hdr = dict(zip(hdr.dtype.names, hdr[0]))
+            # Convert the first record of `hdr_rec` into a dictionary
+            hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0]))
 
             # Check endianness
-            #hdr = append_fields(hdr, Field.ENDIAN, [native_code], usemask=False)
             hdr[Field.ENDIAN] = native_code
-            if hdr['hdr_size'] != 1000:
+            if hdr['hdr_size'] != cls.HEADER_SIZE:
                 hdr[Field.ENDIAN] = swapped_code
-                hdr = hdr.newbyteorder()
-                if hdr['hdr_size'] != 1000:
-                    raise HeaderError('Invalid hdr_size of {0}'.format(hdr['hdr_size']))
+                hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0].newbyteorder()))  # Swap byte order
+                if hdr['hdr_size'] != cls.HEADER_SIZE:
+                    warnings.warn("Invalid hdr_size: {0} instead of {1}".format(hdr['hdr_size'], cls.HEADER_SIZE))
+                    f.seek(start_position, os.SEEK_SET)  # Restore the file position.
+                    return False
+
+            # By default, the voxel order is LPS.
+            # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates
+            if hdr[Field.VOXEL_ORDER] == "":
+                is_consistent = False
+                warnings.warn("Voxel order is not specified; assuming 'LPS' since it is TrackVis software's default.")
+
             # Add more header fields implied by trk format.
             i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4")
             f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4")
 
-            nb_streamlines = 0
+            pts_and_scalars_size = (3 + hdr[Field.NB_SCALARS_PER_POINT]) * f4_dtype.itemsize
+            properties_size = hdr[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize
 
-            #Either verify the number of streamlines specified in the header is correct or
-            # count the actual number of streamlines in case it was not specified in the header.
+            # Verify that the number of streamlines specified in the header is correct.
+            nb_streamlines = 0
             while True:
                 # Read number of points of the streamline
-                buf = fileobj.read(i4_dtype.itemsize)
+                buf = f.read(i4_dtype.itemsize)
 
                 if buf == b'':
-                    break # EOF
+                    break  # EOF
 
-                nb_pts = np.fromstring(buf,
-                                       dtype=i4_dtype,
-                                       count=1)
+                nb_pts = struct.unpack(i4_dtype.str[:-1], buf)[0]
 
-                bytes_to_skip = nb_pts * 3  # x, y, z coordinates
-                bytes_to_skip += nb_pts * hdr[Field.NB_SCALARS_PER_POINT]
-                bytes_to_skip += hdr[Field.NB_PROPERTIES_PER_STREAMLINE]
+                # These sizes are already in bytes (they include the f4 itemsize).
+                bytes_to_skip = nb_pts * pts_and_scalars_size
+                bytes_to_skip += properties_size
 
                 # Seek to the next streamline in the file.
-                fileobj.seek(bytes_to_skip * f4_dtype.itemsize, os.SEEK_CUR)
+                f.seek(bytes_to_skip, os.SEEK_CUR)
 
                 nb_streamlines += 1
 
             if hdr[Field.NB_STREAMLINES] != nb_streamlines:
-                warnings.warn('The number of streamlines specified in header ({1}) does not match ' +
-                              'the actual number of streamlines contained in this file ({1}). ' +
-                              'The latter will be used.'.format(hdr[Field.NB_STREAMLINES], nb_streamlines))
+                is_consistent = False
+                warnings.warn(('The number of streamlines specified in the header ({0}) does not match '
                                'the actual number of streamlines contained in this file ({1}).'
+                               ).format(hdr[Field.NB_STREAMLINES], nb_streamlines))
 
-            hdr[Field.NB_STREAMLINES] = nb_streamlines
+            f.seek(start_position, os.SEEK_SET)  # Restore the file position.
+
+        return is_consistent
+
+    @classmethod
+    def load(cls, fileobj, lazy_load=True):
+        ''' Loads streamlines from a file-like object.
+
+        Parameters
+        ----------
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            pointing to TRK file (and ready to read from the beginning
+            of the TRK header)
 
-            trk_file = cls(hdr, [], [], [])
-            trk_file.pos_header = pos_header
-            trk_file.pos_data = pos_data
-            trk_file.streamlines
+        lazy_load : boolean
+            If True, load streamlines lazily, i.e. do not keep them all
+            in memory. For faster postprocessing, turn this option off.
 
-            return trk_file
-            # cls(hdr, streamlines, scalars, properties)
+        Returns
+        -------
+        streamlines : Streamlines object
+            Returns an object containing streamlines' data and header
+            information. See 'nibabel.Streamlines'.
+ ''' + hdr = {} - def get_header(self): - return self.hdr + with Opener(fileobj) as f: + hdr['pos_header'] = f.tell() - def get_points(self, as_generator=False): - self.fileobj.seek(self.pos_data, os.SEEK_SET) - pos = self.pos_data + # Read header + hdr_str = f.read(header_2_dtype.itemsize) + hdr_rec = np.fromstring(string=hdr_str, dtype=header_2_dtype) - i4_dtype = np.dtype(self.hdr[Field.ENDIAN] + "i4") - f4_dtype = np.dtype(self.hdr[Field.ENDIAN] + "f4") + if hdr_rec['version'] == 1: + hdr_rec = np.fromstring(string=hdr_str, dtype=header_1_dtype) + elif hdr_rec['version'] == 2: + pass # Nothing more to do here + else: + raise HeaderError('NiBabel only supports versions 1 and 2.') - for i in range(self.hdr[Field.NB_STREAMLINES]): - # Read number of points of the streamline - nb_pts = np.fromstring(self.fileobj.read(i4_dtype.itemsize), - dtype=i4_dtype, - count=1) - - # Read points of the streamline - pts = np.fromstring(self.fileobj.read(nb_pts * 3 * i4_dtype.itemsize), - dtype=[f4_dtype, f4_dtype, f4_dtype], - count=nb_pts) + # Convert the first record of `hdr_rec` into a dictionnary + hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0])) - pos = self.fileobj.tell() - yield pts - self.fileobj.seek(pos, os.SEEK_SET) - - bytes_to_skip = nb_pts * self.hdr[Field.NB_SCALARS_PER_POINT] - bytes_to_skip += self.hdr[Field.NB_PROPERTIES_PER_STREAMLINE] - - # Seek to the next streamline in the file. - self.fileobj.seek(bytes_to_skip * f4_dtype.itemsize, os.SEEK_CUR) - - ##### - # Methods - ### - - - - -# import os -# import logging -# import numpy as np - -# from tractconverter.formats.header import Header as H - - -# def readBinaryBytes(f, nbBytes, dtype): -# buff = f.read(nbBytes * dtype.itemsize) -# return np.frombuffer(buff, dtype=dtype) - - -# class TRK: -# # self.hdr -# # self.filename -# # self.hdr[H.ENDIAN] -# # self.FIBER_DELIMITER -# # self.END_DELIMITER - -# @staticmethod -# def create(filename, hdr, anatFile=None): -# f = open(filename, 'wb') -# f.write(TRK.MAGIC_NUMBER + "\n") -# f.close() - -# trk = TRK(filename, load=False) -# trk.hdr = hdr -# trk.writeHeader() - -# return trk - -# ##### -# # Methods -# ### -# def __init__(self, filename, anatFile=None, load=True): -# if not TRK._check(filename): -# raise NameError("Not a TRK file.") - -# self.filename = filename -# self.hdr = {} -# if load: -# self._load() - -# def _load(self): -# f = open(self.filename, 'rb') - -# ##### -# # Read header -# ### -# self.hdr[H.MAGIC_NUMBER] = f.read(6) -# self.hdr[H.DIMENSIONS] = np.frombuffer(f.read(6), dtype='i4') -# self.hdr["version"] = self.hdr["version"].astype('>i4') -# self.hdr["hdr_size"] = self.hdr["hdr_size"].astype('>i4') - -# nb_fibers = 0 -# self.hdr[H.NB_POINTS] = 0 - -# #Either verify the number of streamlines specified in the header is correct or -# # count the actual number of streamlines in case it is not specified in the header. -# remainingBytes = os.path.getsize(self.filename) - self.OFFSET -# while remainingBytes > 0: -# # Read points -# nbPoints = readBinaryBytes(f, 1, np.dtype(self.hdr[H.ENDIAN] + "i4"))[0] -# self.hdr[H.NB_POINTS] += nbPoints -# # This seek is used to go to the next points number indication in the file. 
-# f.seek((nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]) -# + self.hdr[H.NB_PROPERTIES_PER_STREAMLINE]) * 4, 1) # Relative seek -# remainingBytes -= (nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]) -# + self.hdr[H.NB_PROPERTIES_PER_STREAMLINE]) * 4 + 4 -# nb_fibers += 1 - -# if self.hdr[H.NB_FIBERS] != nb_fibers: -# logging.warn('The number of streamlines specified in header ({1}) does not match ' + -# 'the actual number of streamlines contained in this file ({1}). ' + -# 'The latter will be used.'.format(self.hdr[H.NB_FIBERS], nb_fibers)) - -# self.hdr[H.NB_FIBERS] = nb_fibers - -# f.close() - -# def writeHeader(self): -# # Get the voxel size and format it as an array. -# voxel_sizes = np.asarray(self.hdr.get(H.VOXEL_SIZES, (1.0, 1.0, 1.0)), dtype=' 0: -# # Read points -# nbPoints = readBinaryBytes(f, 1, np.dtype(self.hdr[H.ENDIAN] + "i4"))[0] -# ptsAndScalars = readBinaryBytes(f, -# nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]), -# np.dtype(self.hdr[H.ENDIAN] + "f4")) - -# newShape = [-1, 3 + self.hdr[H.NB_SCALARS_PER_POINT]] -# ptsAndScalars = ptsAndScalars.reshape(newShape) - -# pointsWithoutScalars = ptsAndScalars[:, 0:3] -# yield pointsWithoutScalars - -# # For now, we do not process the tract properties, so just skip over them. -# remainingBytes -= nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]) * 4 + 4 -# remainingBytes -= self.hdr[H.NB_PROPERTIES_PER_STREAMLINE] * 4 -# cpt += 1 - -# f.close() - -# def __str__(self): -# text = "" -# text += "MAGIC NUMBER: {0}".format(self.hdr[H.MAGIC_NUMBER]) -# text += "v.{0}".format(self.hdr['version']) -# text += "dim: {0}".format(self.hdr[H.DIMENSIONS]) -# text += "voxel_sizes: {0}".format(self.hdr[H.VOXEL_SIZES]) -# text += "orgin: {0}".format(self.hdr[H.ORIGIN]) -# text += "nb_scalars: {0}".format(self.hdr[H.NB_SCALARS_PER_POINT]) -# text += "scalar_name:\n {0}".format("\n".join(self.hdr['scalar_name'])) -# text += "nb_properties: {0}".format(self.hdr[H.NB_PROPERTIES_PER_STREAMLINE]) -# text += "property_name:\n {0}".format("\n".join(self.hdr['property_name'])) -# text += "vox_to_world: {0}".format(self.hdr[H.VOXEL_TO_WORLD]) -# text += "world_order: {0}".format(self.hdr[H.WORLD_ORDER]) -# text += "voxel_order: {0}".format(self.hdr[H.VOXEL_ORDER]) -# text += "image_orientation_patient: {0}".format(self.hdr['image_orientation_patient']) -# text += "pad1: {0}".format(self.hdr['pad1']) -# text += "pad2: {0}".format(self.hdr['pad2']) -# text += "invert_x: {0}".format(self.hdr['invert_x']) -# text += "invert_y: {0}".format(self.hdr['invert_y']) -# text += "invert_z: {0}".format(self.hdr['invert_z']) -# text += "swap_xy: {0}".format(self.hdr['swap_xy']) -# text += "swap_yz: {0}".format(self.hdr['swap_yz']) -# text += "swap_zx: {0}".format(self.hdr['swap_zx']) -# text += "n_count: {0}".format(self.hdr[H.NB_FIBERS]) -# text += "hdr_size: {0}".format(self.hdr['hdr_size']) -# text += "endianess: {0}".format(self.hdr[H.ENDIAN]) - -# return text + # Check endianness + hdr[Field.ENDIAN] = native_code + if hdr['hdr_size'] != cls.HEADER_SIZE: + hdr[Field.ENDIAN] = swapped_code + hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0].newbyteorder())) # Swap byte order + if hdr['hdr_size'] != cls.HEADER_SIZE: + raise HeaderError('Invalid hdr_size: {0} instead of {1}'.format(hdr['hdr_size'], cls.HEADER_SIZE)) + + # By default, the voxel order is LPS. 
+ # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates + if hdr[Field.VOXEL_ORDER] == "": + hdr[Field.VOXEL_ORDER] = "LPS" + + # Add more header fields implied by trk format. + hdr[Field.WORLD_ORDER] = "RAS" + hdr['pos_data'] = f.tell() + + points = lambda: (x[0] for x in TrkFile._read_data(hdr, fileobj)) + scalars = lambda: (x[1] for x in TrkFile._read_data(hdr, fileobj)) + properties = lambda: (x[2] for x in TrkFile._read_data(hdr, fileobj)) + data = lambda: TrkFile._read_data(hdr, fileobj) + + if lazy_load: + streamlines = Streamlines(points, scalars, properties, hdr=hdr) + streamlines.data = data + return streamlines + + return Streamlines(*zip(*data()), hdr=hdr) + + @staticmethod + def _read_data(hdr, fileobj): + ''' Read streamlines' data from a file-like object using a TRK's header. ''' + i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4") + f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4") + + with Opener(fileobj) as f: + nb_pts_and_scalars = 3 + int(hdr[Field.NB_SCALARS_PER_POINT]) + pts_and_scalars_size = nb_pts_and_scalars * f4_dtype.itemsize + + slice_pts_and_scalars = lambda data: (data, []) + if hdr[Field.NB_SCALARS_PER_POINT] > 0: + # This is faster than np.split + slice_pts_and_scalars = lambda data: (data[:, :3], data[:, 3:]) + + # Using np.fromfile would be faster, but does not support StringIO + read_pts_and_scalars = lambda nb_pts: slice_pts_and_scalars(np.ndarray(shape=(nb_pts, nb_pts_and_scalars), + dtype=f4_dtype, + buffer=f.read(nb_pts * pts_and_scalars_size))) + + properties_size = int(hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) * f4_dtype.itemsize + read_properties = lambda: [] + if hdr[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: + read_properties = lambda: np.fromstring(f.read(properties_size), + dtype=f4_dtype, + count=hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) + + f.seek(hdr['pos_data'], os.SEEK_SET) + + for i in xrange(hdr[Field.NB_STREAMLINES]): + # Read number of points of the next streamline. + nb_pts = struct.unpack(i4_dtype.str[:-1], f.read(i4_dtype.itemsize))[0] + + # Read streamline's data + pts, scalars = read_pts_and_scalars(nb_pts) + properties = read_properties() + yield pts, scalars, properties + + @classmethod + def get_empty_header(cls): + ''' Return an empty TRK's header. ''' + hdr = np.zeros(1, dtype=header_2_dtype) + + #Default values + hdr[Field.MAGIC_NUMBER] = cls.MAGIC_NUMBER + hdr[Field.VOXEL_SIZES] = (1, 1, 1) + hdr[Field.DIMENSIONS] = (1, 1, 1) + hdr[Field.VOXEL_TO_WORLD] = np.eye(4) + hdr['version'] = 2 + hdr['hdr_size'] = cls.HEADER_SIZE + + return hdr + + @classmethod + def save(cls, streamlines, fileobj): + ''' Saves streamlines to a file-like object. + + Parameters + ---------- + streamlines : Streamlines object + Object containing streamlines' data and header information. + See 'nibabel.Streamlines'. + + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to TRK file (and ready to read from the beginning + of the TRK header data) + ''' + hdr = cls.get_empty_header() + + #Override hdr's fields by those contain in `streamlines`'s header + for k, v in streamlines.get_header().items(): + if k in header_2_dtype.fields.keys(): + hdr[k] = v + + # Check which endianess to use to write data. 
+        endianess = streamlines.get_header().get(Field.ENDIAN, native_code)
+
+        if endianess == swapped_code:
+            hdr = hdr.newbyteorder()
+
+        i4_dtype = np.dtype(endianess + "i4")
+        f4_dtype = np.dtype(endianess + "f4")
+
+        # Keep counts so incoherent header fields can be corrected or warned about.
+        nb_streamlines = 0
+        nb_points = 0
+        nb_scalars = 0
+        nb_properties = 0
+
+        # Write header + data of streamlines
+        with Opener(fileobj, mode="wb") as f:
+            pos = f.tell()
+            f.write(hdr[0].tostring())
+
+            for points, scalars, properties in streamlines:
+                if len(scalars) > 0 and len(scalars) != len(points):
+                    raise DataError("Missing scalars for some points!")
+
+                points = np.array(points, dtype=f4_dtype)
+                scalars = np.array(scalars, dtype=f4_dtype).reshape((len(points), -1))
+                properties = np.array(properties, dtype=f4_dtype)
+
+                data = struct.pack(i4_dtype.str[:-1], len(points))
+                data += np.concatenate((points, scalars), axis=1).tostring()
+                data += properties.tostring()
+                f.write(data)
+
+                nb_streamlines += 1
+                nb_points += len(points)
+                nb_scalars += scalars.size
+                nb_properties += len(properties)
+
+            # Either correct the header or warn if header and data are incoherent.
+            # TODO: add a warn option as a function parameter
+            # Guard against empty files: no points or streamlines were written.
+            nb_scalars_per_point = nb_scalars / nb_points if nb_points > 0 else 0
+            nb_properties_per_streamline = nb_properties / nb_streamlines if nb_streamlines > 0 else 0
+
+            # Check for errors
+            if nb_scalars_per_point != int(nb_scalars_per_point):
+                raise DataError("Nb. of scalars differs from one point to another!")
+
+            if nb_properties_per_streamline != int(nb_properties_per_streamline):
+                raise DataError("Nb. of properties differs from one streamline to another!")
+
+            hdr[Field.NB_STREAMLINES] = nb_streamlines
+            hdr[Field.NB_SCALARS_PER_POINT] = nb_scalars_per_point
+            hdr[Field.NB_PROPERTIES_PER_STREAMLINE] = nb_properties_per_streamline
+
+            f.seek(pos, os.SEEK_SET)
+            f.write(hdr[0].tostring())  # Overwrite header with updated one.
+
+    @staticmethod
+    def pretty_print(streamlines):
+        ''' Gets a formatted string containing header's information
+        relevant to the TRK format.
+
+        Parameters
+        ----------
+        streamlines : Streamlines object
+            Object containing streamlines' data and header information.
+            See 'nibabel.Streamlines'.
+
+        Returns
+        -------
+        info : string
+            Header's information relevant to the TRK format.
+        '''
+        hdr = streamlines.get_header()
+
+        info = ""
+        info += "MAGIC NUMBER: {0}\n".format(hdr[Field.MAGIC_NUMBER])
+        info += "v.{0}\n".format(hdr['version'])
+        info += "dim: {0}\n".format(hdr[Field.DIMENSIONS])
+        info += "voxel_sizes: {0}\n".format(hdr[Field.VOXEL_SIZES])
+        info += "origin: {0}\n".format(hdr[Field.ORIGIN])
+        info += "nb_scalars: {0}\n".format(hdr[Field.NB_SCALARS_PER_POINT])
+        info += "scalar_name:\n {0}\n".format("\n".join(hdr['scalar_name']))
+        info += "nb_properties: {0}\n".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE])
+        info += "property_name:\n {0}\n".format("\n".join(hdr['property_name']))
+        info += "vox_to_world: {0}\n".format(hdr[Field.VOXEL_TO_WORLD])
+        info += "world_order: {0}\n".format(hdr[Field.WORLD_ORDER])
+        info += "voxel_order: {0}\n".format(hdr[Field.VOXEL_ORDER])
+        info += "image_orientation_patient: {0}\n".format(hdr['image_orientation_patient'])
+        info += "pad1: {0}\n".format(hdr['pad1'])
+        info += "pad2: {0}\n".format(hdr['pad2'])
+        info += "invert_x: {0}\n".format(hdr['invert_x'])
+        info += "invert_y: {0}\n".format(hdr['invert_y'])
+        info += "invert_z: {0}\n".format(hdr['invert_z'])
+        info += "swap_xy: {0}\n".format(hdr['swap_xy'])
+        info += "swap_yz: {0}\n".format(hdr['swap_yz'])
+        info += "swap_zx: {0}\n".format(hdr['swap_zx'])
+        info += "n_count: {0}\n".format(hdr[Field.NB_STREAMLINES])
+        info += "hdr_size: {0}\n".format(hdr['hdr_size'])
+        info += "endianness: {0}\n".format(hdr[Field.ENDIAN])
+
+        return info
diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py
index 511f27fe01..bb5515988d 100644
--- a/nibabel/streamlines/utils.py
+++ b/nibabel/streamlines/utils.py
@@ -1,47 +1,135 @@
 import os
 
-import nibabel as nib
+from ..externals.six import string_types
 
 from nibabel.streamlines.trk import TrkFile
 #from nibabel.streamlines.tck import TckFile
 #from nibabel.streamlines.vtk import VtkFile
-#from nibabel.streamlines.fib import FibFile
 
-# Supported format
-FORMATS = {"trk": TrkFile,
-           #"tck": TckFile,
-           #"vtk": VtkFile,
-           #"fib": FibFile,
+# List of all supported formats
+FORMATS = {".trk": TrkFile,
+           #".tck": TckFile,
+           #".vtk": VtkFile,
           }
 
+
 def is_supported(fileobj):
+    ''' Checks whether the file-like object is supported by NiBabel.
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object pointing
+        to a streamlines file (and ready to read from the beginning of the
+        header)
+
+    Returns
+    -------
+    is_supported : boolean
+    '''
     return detect_format(fileobj) is not None
 
 
def detect_format(fileobj):
+    ''' Returns the StreamlinesFile object guessed from the file-like object.
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object pointing
+        to a streamlines file (and ready to read from the beginning of the
+        header)
+
+    Returns
+    -------
+    streamlines_file : StreamlinesFile object
+        Object that can be used to manage a streamlines file.
+        See 'nibabel.streamlines.StreamlinesFile'.
+    '''
     for format in FORMATS.values():
-        if format.is_correct_format(fileobj):
-            return format
+        try:
+            if format.is_correct_format(fileobj):
+                return format
+
+        except IOError:
+            pass
 
-    if isinstance(fileobj, basestring):
+    if isinstance(fileobj, string_types):
         _, ext = os.path.splitext(fileobj)
         return FORMATS.get(ext, None)
 
     return None
 
 
+def load(fileobj, lazy_load=True, anat=None):
+    ''' Loads streamlines from a file-like object in their native space.
+ + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the beginning + of the streamlines file's header) + + Returns + ------- + obj : instance of ``Streamlines`` + Returns an instance of a ``Streamlines`` class containing data and metadata + of streamlines loaded from ``fileobj``. + ''' + + # TODO: Ask everyone what should be the behavior if the anat is provided. + # if anat is None: + # warnings.warn("WARNING: Streamlines will be loaded in their native space (i.e. as they were saved).") + + streamlines_file = detect_format(fileobj) + + if streamlines_file is None: + raise TypeError("Unknown format for 'fileobj': {0}!".format(fileobj)) + + return streamlines_file.load(fileobj, lazy_load=lazy_load) + + +def save(streamlines, filename): + ''' Saves a ``Streamlines`` object to a file + + Parameters + ---------- + streamlines : Streamlines object + Streamlines to be saved (metadata is obtained with the function ``get_header`` of ``streamlines``). + + filename : string + Name of the file where the streamlines will be saved. The format will be guessed from ``filename``. + ''' + + streamlines_file = detect_format(filename) + + if streamlines_file is None: + raise TypeError("Unknown format for 'filename': {0}!".format(filename)) + + streamlines_file.save(streamlines, filename) + + def convert(in_fileobj, out_filename): - in_fileobj = nib.streamlines.load(in_fileobj) - out_format = nib.streamlines.guess_format(out_filename) + ''' Converts one streamlines format to another. + + It does not change the space in which the streamlines are. - hdr = in_fileobj.get_header() - points = in_fileobj.get_points(as_generator=True) - scalars = in_fileobj.get_scalars(as_generator=True) - properties = in_fileobj.get_properties(as_generator=True) + Parameters + ---------- + in_fileobj : string or file-like object + If string, a filename; otherwise an open file-like object pointing + to a streamlines file (and ready to read from the beginning of the + header) - out_fileobj = out_format(hdr, points, scalars, properties) - out_fileobj.save(out_filename) + out_filename : string + Name of the file where the streamlines will be saved. The format will + be guessed from ``out_filename``. 
+ ''' + streamlines = load(in_fileobj, lazy_load=True) + save(streamlines, out_filename) +# TODO def change_space(streamline_file, new_point_space): - pass \ No newline at end of file + pass diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py index 469570e1d5..28dcf2ede2 100644 --- a/nibabel/testing/__init__.py +++ b/nibabel/testing/__init__.py @@ -8,6 +8,10 @@ ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ## ''' Utilities for testing ''' from os.path import dirname, abspath, join as pjoin +from warnings import catch_warnings +from nibabel.externals.six.moves import zip +from numpy.testing import assert_array_equal + import numpy as np from warnings import catch_warnings, simplefilter @@ -56,6 +60,20 @@ def assert_allclose_safely(a, b, match_nans=True): assert_true(np.allclose(a, b)) +def assert_arrays_equal(arrays1, arrays2): + for arr1, arr2 in zip(arrays1, arrays2): + assert_array_equal(arr1, arr2) + + +def assert_streamlines_equal(s1, s2): + assert_equal(s1.get_header(), s2.get_header()) + + for (points1, scalars1, properties1), (points2, scalars2, properties2) in zip(s1, s2): + assert_array_equal(points1, points2) + assert_array_equal(scalars1, scalars2) + assert_array_equal(properties1, properties2) + + class suppress_warnings(catch_warnings): """ Version of ``catch_warnings`` class that suppresses warnings """ diff --git a/setup.py b/setup.py index db0b69866b..9ebaf7747b 100755 --- a/setup.py +++ b/setup.py @@ -89,6 +89,7 @@ def main(**extra_args): 'nibabel.testing', 'nibabel.tests', 'nibabel.benchmarks', + 'nibabel.streamlines', # install nisext as its own package 'nisext', 'nisext.tests'], @@ -103,6 +104,7 @@ def main(**extra_args): pjoin('externals', 'tests', 'data', '*'), pjoin('nicom', 'tests', 'data', '*'), pjoin('gifti', 'tests', 'data', '*'), + pjoin('streamlines', 'tests', 'data', '*'), ]}, scripts = [pjoin('bin', 'parrec2nii'), pjoin('bin', 'nib-ls'), From 532c32c5dc8889a92c56c3927985f0287512cc52 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 4 Mar 2015 10:19:22 -0500 Subject: [PATCH 03/54] Added a LazyStreamlines class and made sure streamlines are in voxel space. 
--- nibabel/streamlines/__init__.py | 2 +- nibabel/streamlines/base_format.py | 158 ++++++++++++-- nibabel/streamlines/header.py | 51 ++++- nibabel/streamlines/tests/test_base_format.py | 205 ++++++++++++++++-- nibabel/streamlines/tests/test_trk.py | 48 ++-- nibabel/streamlines/tests/test_utils.py | 62 +++--- nibabel/streamlines/trk.py | 135 +++++++++--- nibabel/streamlines/utils.py | 37 +++- 8 files changed, 555 insertions(+), 143 deletions(-) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index cd7cd56a21..7f00864c5f 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -1,7 +1,7 @@ from nibabel.streamlines.utils import load, save -from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines.base_format import Streamlines, LazyStreamlines from nibabel.streamlines.header import Field from nibabel.streamlines.trk import TrkFile diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 43e3b8616e..30baf4001e 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -1,3 +1,5 @@ +import numpy as np +from warnings import warn from nibabel.streamlines.header import Field @@ -21,45 +23,126 @@ class Streamlines(object): Parameters ---------- - points : sequence of ndarray of shape (N, 3) + points : list of ndarray of shape (N, 3) Sequence of T streamlines. One streamline is an ndarray of shape (N, 3) where N is the number of points in a streamline. - scalars : sequence of ndarray of shape (N, M) + scalars : list of ndarray of shape (N, M) Sequence of T ndarrays of shape (N, M) where T is the number of streamlines defined by ``points``, N is the number of points for a particular streamline and M is the number of scalars associated to each point (excluding the three coordinates). - properties : sequence of ndarray of shape (P,) + properties : list of ndarray of shape (P,) Sequence of T ndarrays of shape (P,) where T is the number of streamlines defined by ``points``, P is the number of properties - associated to each streamlines. + associated to each streamline. hdr : dict Header containing meta information about the streamlines. For a list of common header's fields to use as keys see `nibabel.streamlines.Field`. ''' - def __init__(self, points=[], scalars=[], properties=[], hdr={}): - self.hdr = hdr + def __init__(self, points=[], scalars=[], properties=[]): #, hdr={}): + # Create basic header from given informations. + self._header = {} + self._header[Field.VOXEL_TO_WORLD] = np.eye(4) self.points = points self.scalars = scalars self.properties = properties - self.data = lambda: zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) - try: - self.length = len(points) - except: - if Field.NB_STREAMLINES in hdr: - self.length = hdr[Field.NB_STREAMLINES] - else: - raise HeaderError(("Neither parameter 'points' nor 'hdr' contain information about" - " number of streamlines. 
Use key '{0}' to set the number of " - "streamlines in 'hdr'.").format(Field.NB_STREAMLINES)) - def get_header(self): - return self.hdr + @property + def header(self): + return self._header + + @property + def points(self): + return self._points + + @points.setter + def points(self, value): + self._points = value + self._header[Field.NB_STREAMLINES] = len(self.points) + + @property + def scalars(self): + return self._scalars + + @scalars.setter + def scalars(self, value): + self._scalars = value + self._header[Field.NB_SCALARS_PER_POINT] = 0 + if len(self.scalars) > 0: + self._header[Field.NB_SCALARS_PER_POINT] = len(self.scalars[0]) + + @property + def properties(self): + return self._properties + + @properties.setter + def properties(self, value): + self._properties = value + self._header[Field.NB_PROPERTIES_PER_STREAMLINE] = 0 + if len(self.properties) > 0: + self._header[Field.NB_PROPERTIES_PER_STREAMLINE] = len(self.properties[0]) + + def __iter__(self): + return zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) + + def __getitem__(self, idx): + pts = self.points[idx] + scalars = [] + if len(self.scalars) > 0: + scalars = self.scalars[idx] + + properties = [] + if len(self.properties) > 0: + properties = self.properties[idx] + + return pts, scalars, properties + + def __len__(self): + return len(self.points) + + +class LazyStreamlines(Streamlines): + ''' Class containing information about streamlines. + + Streamlines objects have three main properties: ``points``, ``scalars`` + and ``properties``. Streamlines objects can be iterate over producing + tuple of ``points``, ``scalars`` and ``properties`` for each streamline. + + Parameters + ---------- + points : sequence of ndarray of shape (N, 3) + Sequence of T streamlines. One streamline is an ndarray of shape (N, 3) + where N is the number of points in a streamline. + + scalars : sequence of ndarray of shape (N, M) + Sequence of T ndarrays of shape (N, M) where T is the number of + streamlines defined by ``points``, N is the number of points + for a particular streamline and M is the number of scalars + associated to each point (excluding the three coordinates). + + properties : sequence of ndarray of shape (P,) + Sequence of T ndarrays of shape (P,) where T is the number of + streamlines defined by ``points``, P is the number of properties + associated to each streamline. + + hdr : dict + Header containing meta information about the streamlines. For a list + of common header's fields to use as keys see `nibabel.streamlines.Field`. 
+    '''
+    def __init__(self, points=[], scalars=[], properties=[], data=None, count=None, getitem=None):  #, hdr={}):
+        super(LazyStreamlines, self).__init__(points, scalars, properties)
+
+        self._data = lambda: zip_longest(self.points, self.scalars, self.properties, fillvalue=[])
+        if data is not None:
+            self._data = data if callable(data) else lambda: data
+
+        self._count = count
+        self._getitem = getitem
 
     @property
     def points(self):
@@ -67,7 +150,7 @@ def points(self):
 
     @points.setter
     def points(self, value):
-        self._points = value if callable(value) else (lambda: value)
+        self._points = value if callable(value) else lambda: value
 
     @property
     def scalars(self):
@@ -85,11 +168,44 @@ def properties(self):
     def properties(self, value):
         self._properties = value if callable(value) else lambda: value
 
+    def __getitem__(self, idx):
+        if self._getitem is None:
+            raise AttributeError('`LazyStreamlines` does not support indexing.')
+
+        return self._getitem(idx)
+
     def __iter__(self):
-        return self.data()
+        return self._data()
 
     def __len__(self):
-        return self.length
+        # If the length is unknown, try to get it as rapidly and accurately as possible.
+        if self._count is None:
+            # Length might be contained in the header.
+            if Field.NB_STREAMLINES in self.header:
+                return self.header[Field.NB_STREAMLINES]
+
+        if callable(self._count):
+            # Length might be obtained by re-parsing the file (if streamlines come from one).
+            self._count = self._count()
+
+        if self._count is None:
+            try:
+                # Will work if `points` is a finite sequence (e.g. list, ndarray).
+                self._count = len(self.points)
+            except TypeError:
+                pass
+
+        if self._count is None:
+            # As a last resort, count them by iterating through the points.
+            warn("Number of streamlines will be determined manually by looping"
+                 " through the streamlines. Note this will consume any"
+                 " generator used to create this `Streamlines` object. If you"
+                 " know the actual number of streamlines, you might want to"
+                 " set `Field.NB_STREAMLINES` of `self.header` beforehand.")
+
+            return sum(1 for _ in self)
+
+        return self._count
 
 
 class StreamlinesFile:
diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py
index aa40d3bb97..357930c012 100644
--- a/nibabel/streamlines/header.py
+++ b/nibabel/streamlines/header.py
@@ -1,4 +1,7 @@
+import numpy as np
+import nibabel
 from nibabel.orientations import aff2axcodes
+from nibabel.spatialimages import SpatialImage
 
 
 class Field:
@@ -18,12 +21,23 @@ class Field:
     ORIGIN = "origin"
     VOXEL_TO_WORLD = "voxel_to_world"
     VOXEL_ORDER = "voxel_order"
-    WORLD_ORDER = "world_order"
+    #WORLD_ORDER = "world_order"
     ENDIAN = "endian"
 
 
+def create_header_from_reference(ref):
+    if type(ref) is np.ndarray:
+        return create_header_from_affine(ref)
+    elif isinstance(ref, SpatialImage):
+        return create_header_from_nifti(ref)
+
+    # Assume `ref` is a filename:
+    img = nibabel.load(ref)
+    return create_header_from_nifti(img)
+
+
 def create_header_from_nifti(img):
-    ''' Creates a common streamlines' header using a nifti image.
+    ''' Creates a common streamlines' header using a spatial image.
 
     Based on the information of the nifti image a dictionary is created
     containing the following keys: `Field.ORIGIN`, `Field.DIMENSIONS`,
@@ -32,7 +46,7 @@ def create_header_from_nifti(img):
 
     Parameters
    ----------
-    img : Nifti1Image object
+    img : `SpatialImage` object
         Image containing information about the anatomy where streamlines
         were created.
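# A minimal sketch (not part of the patch) of the affine-to-header math used
# by `create_header_from_affine` in the hunk below: the voxel sizes are the
# column norms of the 3x3 block (these norms are invariant to rotations),
# and the origin is the translation column. The numbers are only an example.
#
#     import numpy as np
#
#     affine = np.diag([2.0, 3.0, 4.0, 1.0])  # 2 x 3 x 4 mm voxels
#     affine[:3, -1] = [10, 20, 30]           # translation part (origin)
#
#     voxel_sizes = np.sqrt(np.sum(affine[:3, :3] ** 2, axis=0))
#     # -> array([ 2.,  3.,  4.])
#     origin = affine[:3, -1]                 # -> array([ 10.,  20.,  30.])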
@@ -51,7 +65,36 @@ def create_header_from_nifti(img): hdr[Field.DIMENSIONS] = img_header.get_data_shape()[:3] hdr[Field.VOXEL_SIZES] = img_header.get_zooms()[:3] hdr[Field.VOXEL_TO_WORLD] = affine - hdr[Field.WORLD_ORDER] = "RAS" # Nifti space + #hdr[Field.WORLD_ORDER] = "RAS" # Nifti space + hdr[Field.VOXEL_ORDER] = "".join(aff2axcodes(affine)) + + return hdr + + +def create_header_from_affine(affine): + ''' Creates a common streamlines' header using an affine matrix. + + Based on the information of the affine matrix a dictionnary is created + containing the following keys: `Field.ORIGIN`, `Field.DIMENSIONS`, + `Field.VOXEL_SIZES`, `Field.VOXEL_TO_WORLD`, `Field.WORLD_ORDER` + and `Field.VOXEL_ORDER`. + + Parameters + ---------- + affine : 2D array (3,3) | 2D array (4,4) + Affine matrix that transforms streamlines from voxel space to world-space. + + Returns + ------- + hdr : dict + Header containing meta information about streamlines. + ''' + hdr = {} + + hdr[Field.ORIGIN] = affine[:3, -1] + hdr[Field.VOXEL_SIZES] = np.sqrt(np.sum(affine[:3, :3]**2, axis=0)) + hdr[Field.VOXEL_TO_WORLD] = affine + #hdr[Field.WORLD_ORDER] = "RAS" # Nifti space hdr[Field.VOXEL_ORDER] = "".join(aff2axcodes(affine)) return hdr diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index 25956bcfed..624b2d23d5 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -4,15 +4,16 @@ from nibabel.testing import assert_arrays_equal from nose.tools import assert_equal, assert_raises +from numpy.testing import assert_array_equal -from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines.base_format import Streamlines, LazyStreamlines from nibabel.streamlines.base_format import HeaderError from nibabel.streamlines.header import Field DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') -class TestStreamlines(unittest.TestCase): +class TestLazyStreamlines(unittest.TestCase): def setUp(self): self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") @@ -33,21 +34,21 @@ def setUp(self): def test_streamlines_creation_from_arrays(self): # Empty - streamlines = Streamlines() - assert_equal(len(streamlines), 0) + streamlines = LazyStreamlines() - # TODO: Should Streamlines have a default header? It could have - # NB_STREAMLINES, NB_SCALARS_PER_POINT and NB_PROPERTIES_PER_STREAMLINE - # already set. - hdr = streamlines.get_header() - assert_equal(len(hdr), 0) + # LazyStreamlines have a default header when created from arrays: + # NB_STREAMLINES, NB_SCALARS_PER_POINT, NB_PROPERTIES_PER_STREAMLINE + # and VOXEL_TO_WORLD. + hdr = streamlines.header + assert_equal(len(hdr), 1) + assert_equal(len(streamlines), 0) # Check if we can iterate through the streamlines. for point, scalar, prop in streamlines: pass # Only points - streamlines = Streamlines(points=self.points) + streamlines = LazyStreamlines(points=self.points) assert_equal(len(streamlines), len(self.points)) assert_arrays_equal(streamlines.points, self.points) @@ -56,15 +57,14 @@ def test_streamlines_creation_from_arrays(self): pass # Only scalars - streamlines = Streamlines(scalars=self.colors) - assert_equal(len(streamlines), 0) + streamlines = LazyStreamlines(scalars=self.colors) assert_arrays_equal(streamlines.scalars, self.colors) # TODO: is it a faulty behavior? 
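# A minimal sketch (not part of the patch) of the len() resolution order that
# the tests below exercise, assuming the LazyStreamlines API from this patch
# (`my_points` is a hypothetical list of ndarrays):
#
#     from nibabel.streamlines import LazyStreamlines
#     from nibabel.streamlines.header import Field
#
#     streamlines = LazyStreamlines(points=(p for p in my_points))
#
#     # 1. An explicit Field.NB_STREAMLINES header entry wins, no iteration:
#     streamlines.header[Field.NB_STREAMLINES] = 42
#     assert len(streamlines) == 42
#
#     # 2. Otherwise a `count` callable is invoked once and its result cached:
#     streamlines = LazyStreamlines(points=(p for p in my_points),
#                                   count=lambda: 42)
#     assert len(streamlines) == 42
#
#     # 3. As a last resort, len() iterates, emitting a warning and consuming
#     #    a generator-backed `points`.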
assert_equal(len(list(streamlines)), len(self.colors)) # Points, scalars and properties - streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) + streamlines = LazyStreamlines(self.points, self.colors, self.mean_curvature_torsion) assert_equal(len(streamlines), len(self.points)) assert_arrays_equal(streamlines.points, self.points) assert_arrays_equal(streamlines.scalars, self.colors) @@ -80,11 +80,11 @@ def test_streamlines_creation_from_generators(self): scalars = (x for x in self.colors) properties = (x for x in self.mean_curvature_torsion) - assert_raises(HeaderError, Streamlines, points, scalars, properties) + streamlines = LazyStreamlines(points, scalars, properties) + + # LazyStreamlines object does not support indexing. + assert_raises(AttributeError, streamlines.__getitem__, 0) - hdr = {Field.NB_STREAMLINES: len(self.points)} - streamlines = Streamlines(points, scalars, properties, hdr) - assert_equal(len(streamlines), len(self.points)) assert_arrays_equal(streamlines.points, self.points) assert_arrays_equal(streamlines.scalars, self.colors) assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) @@ -95,17 +95,31 @@ def test_streamlines_creation_from_generators(self): assert_equal(len(list(streamlines.scalars)), 0) assert_equal(len(list(streamlines.properties)), 0) + # Test function len + points = (x for x in self.points) + streamlines = LazyStreamlines(points, scalars, properties) + + # This will consume generator `points`. + # Note this will produce a warning message. + assert_equal(len(streamlines), len(self.points)) + assert_equal(len(streamlines), 0) + + # It will use `Field.NB_STREAMLINES` if it is in the streamlines header + # Note this won't produce a warning message. + streamlines.header[Field.NB_STREAMLINES] = len(self.points) + assert_equal(len(streamlines), len(self.points)) + def test_streamlines_creation_from_functions(self): # Points, scalars and properties points = lambda: (x for x in self.points) scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - assert_raises(HeaderError, Streamlines, points, scalars, properties) + streamlines = LazyStreamlines(points, scalars, properties) + + # LazyStreamlines object does not support indexing. + assert_raises(AttributeError, streamlines.__getitem__, 0) - hdr = {Field.NB_STREAMLINES: len(self.points)} - streamlines = Streamlines(points, scalars, properties, hdr) - assert_equal(len(streamlines), len(self.points)) assert_arrays_equal(streamlines.points, self.points) assert_arrays_equal(streamlines.scalars, self.colors) assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) @@ -115,3 +129,152 @@ def test_streamlines_creation_from_functions(self): assert_arrays_equal(streamlines.points, self.points) assert_arrays_equal(streamlines.scalars, self.colors) assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + + # Test function `len` + # Calling `len` will create a new generator each time. + # Note this will produce a warning message. + assert_equal(len(streamlines), len(self.points)) + assert_equal(len(streamlines), len(self.points)) + + # It will use `Field.NB_STREAMLINES` if it is in the streamlines header + # Note this won't produce a warning message. 
+ streamlines.header[Field.NB_STREAMLINES] = len(self.points) + assert_equal(len(streamlines), len(self.points)) + + def test_len(self): + # Points, scalars and properties + points = lambda: (x for x in self.points) + + # Function `len` is computed differently depending on available information. + # When `points` is a list, `len` will use `len(points)`. + streamlines = LazyStreamlines(points=self.points) + assert_equal(len(streamlines), len(self.points)) + + # When `points` is a generator, `len` will iterate through the streamlines + # and consume the generator. + # TODO: check that it has raised a warning message. + streamlines = LazyStreamlines(points=points()) + assert_equal(len(streamlines), len(self.points)) + assert_equal(len(streamlines), 0) + + # When `points` is a callable object that creates a generator, `len` will iterate + # through the streamlines. + # TODO: check that it has raised a warning message. + streamlines = LazyStreamlines(points=points) + assert_equal(len(streamlines), len(self.points)) + assert_equal(len(streamlines), len(self.points)) + + + # No matter what `points` is, if `Field.NB_STREAMLINES` is set in the header + # `len` returns that value. If not and `count` argument is specified, `len` + # will use that information to return a value. + # TODO: check that no warning messages are raised. + for pts in [self.points, points(), points]: + # `Field.NB_STREAMLINES` is set in the header. + streamlines = LazyStreamlines(points=pts) + streamlines.header[Field.NB_STREAMLINES] = 42 + assert_equal(len(streamlines), 42) + + # `count` is an integer. + streamlines = LazyStreamlines(points=pts, count=42) + assert_equal(len(streamlines), 42) + + # `count` argument is a callable object. + nb_calls = [0] + def count(): + nb_calls[0] += 1 + return 42 + + streamlines = LazyStreamlines(points=points, count=count) + assert_equal(len(streamlines), 42) + assert_equal(len(streamlines), 42) + # Check that the callable object is only called once (caching). + assert_equal(nb_calls[0], 1) + + +class TestStreamlines(unittest.TestCase): + + def setUp(self): + self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") + self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") + self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") + + self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11, 2.22], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + + def test_streamlines_creation_from_arrays(self): + # Empty + streamlines = Streamlines() + assert_equal(len(streamlines), 0) + + # Streamlines have a default header: + # NB_STREAMLINES, NB_SCALARS_PER_POINT, NB_PROPERTIES_PER_STREAMLINE + # and VOXEL_TO_WORLD. + hdr = streamlines.header + assert_equal(len(hdr), 4) + + # Check if we can iterate through the streamlines. + for points, scalars, props in streamlines: + pass + + # Only points + streamlines = Streamlines(points=self.points) + assert_equal(len(streamlines), len(self.points)) + assert_arrays_equal(streamlines.points, self.points) + + # Check if we can iterate through the streamlines. 
+ for points, scalars, props in streamlines: + pass + + # Only scalars + streamlines = Streamlines(scalars=self.colors) + assert_equal(len(streamlines), 0) + assert_arrays_equal(streamlines.scalars, self.colors) + + # TODO: is it a faulty behavior? + assert_equal(len(list(streamlines)), len(self.colors)) + + # Points, scalars and properties + streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) + assert_equal(len(streamlines), len(self.points)) + assert_arrays_equal(streamlines.points, self.points) + assert_arrays_equal(streamlines.scalars, self.colors) + assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in streamlines: + pass + + # Retrieves streamlines by their index + for i, (points, scalars, props) in enumerate(streamlines): + points_i, scalars_i, props_i = streamlines[i] + assert_array_equal(points_i, points) + assert_array_equal(scalars_i, scalars) + assert_array_equal(props_i, props) + + def test_streamlines_creation_from_generators(self): + # Points, scalars and properties + points = (x for x in self.points) + scalars = (x for x in self.colors) + properties = (x for x in self.mean_curvature_torsion) + + # To create streamlines from generators use LazyStreamlines. + assert_raises(TypeError, Streamlines, points, scalars, properties) + + def test_streamlines_creation_from_functions(self): + # Points, scalars and properties + points = lambda: (x for x in self.points) + scalars = lambda: (x for x in self.colors) + properties = lambda: (x for x in self.mean_curvature_torsion) + + # To create streamlines from functions use LazyStreamlines. + assert_raises(TypeError, Streamlines, points, scalars, properties) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 251be7a61d..04dea5a88e 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -8,7 +8,7 @@ from nose.tools import assert_equal, assert_raises import nibabel as nib -from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines.base_format import Streamlines, LazyStreamlines from nibabel.streamlines.base_format import DataError, HeaderError from nibabel.streamlines.header import Field from nibabel.streamlines.trk import TrkFile @@ -36,9 +36,9 @@ def setUp(self): np.array([3.11, 3.22], dtype="f4")] def test_load_empty_file(self): - empty_trk = nib.streamlines.load(self.empty_trk_filename, lazy_load=False) + empty_trk = nib.streamlines.load(self.empty_trk_filename, ref=None, lazy_load=False) - hdr = empty_trk.get_header() + hdr = empty_trk.header assert_equal(hdr[Field.NB_STREAMLINES], 0) assert_equal(len(empty_trk), 0) @@ -56,9 +56,9 @@ def test_load_empty_file(self): pass def test_load_simple_file(self): - simple_trk = nib.streamlines.load(self.simple_trk_filename, lazy_load=False) + simple_trk = nib.streamlines.load(self.simple_trk_filename, ref=None, lazy_load=False) - hdr = simple_trk.get_header() + hdr = simple_trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(simple_trk), len(self.points)) @@ -76,9 +76,9 @@ def test_load_simple_file(self): pass # Test lazy_load - simple_trk = nib.streamlines.load(self.simple_trk_filename, lazy_load=True) + simple_trk = nib.streamlines.load(self.simple_trk_filename, ref=None, lazy_load=True) - hdr = simple_trk.get_header() + hdr = simple_trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) 
assert_equal(len(simple_trk), len(self.points)) @@ -96,9 +96,9 @@ def test_load_simple_file(self): pass def test_load_complex_file(self): - complex_trk = nib.streamlines.load(self.complex_trk_filename, lazy_load=False) + complex_trk = nib.streamlines.load(self.complex_trk_filename, ref=None, lazy_load=False) - hdr = complex_trk.get_header() + hdr = complex_trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(complex_trk), len(self.points)) @@ -115,9 +115,9 @@ def test_load_complex_file(self): for point, scalar, prop in complex_trk: pass - complex_trk = nib.streamlines.load(self.complex_trk_filename, lazy_load=True) + complex_trk = nib.streamlines.load(self.complex_trk_filename, ref=None, lazy_load=True) - hdr = complex_trk.get_header() + hdr = complex_trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(complex_trk), len(self.points)) @@ -142,9 +142,9 @@ def test_write_simple_file(self): simple_trk_file.seek(0, os.SEEK_SET) - simple_trk = nib.streamlines.load(simple_trk_file, lazy_load=False) + simple_trk = nib.streamlines.load(simple_trk_file, ref=None, lazy_load=False) - hdr = simple_trk.get_header() + hdr = simple_trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(simple_trk), len(self.points)) @@ -170,9 +170,9 @@ def test_write_complex_file(self): trk_file.seek(0, os.SEEK_SET) - trk = nib.streamlines.load(trk_file, lazy_load=False) + trk = nib.streamlines.load(trk_file, ref=None, lazy_load=False) - hdr = trk.get_header() + hdr = trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(trk), len(self.points)) @@ -197,9 +197,9 @@ def test_write_complex_file(self): trk_file.seek(0, os.SEEK_SET) - trk = nib.streamlines.load(trk_file, lazy_load=False) + trk = nib.streamlines.load(trk_file, ref=None, lazy_load=False) - hdr = trk.get_header() + hdr = trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(trk), len(self.points)) @@ -224,9 +224,9 @@ def test_write_complex_file(self): trk_file.seek(0, os.SEEK_SET) - trk = nib.streamlines.load(trk_file, lazy_load=False) + trk = nib.streamlines.load(trk_file, ref=None, lazy_load=False) - hdr = trk.get_header() + hdr = trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(trk), len(self.points)) @@ -293,19 +293,17 @@ def test_write_file_from_generator(self): gen_scalars = (scalar for scalar in self.colors) gen_properties = (prop for prop in self.mean_curvature_torsion) - assert_raises(HeaderError, Streamlines, points=gen_points, scalars=gen_scalars, properties=gen_properties) - - hdr = {Field.NB_STREAMLINES: len(self.points)} - streamlines = Streamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties, hdr=hdr) + streamlines = LazyStreamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties) + #streamlines.header[Field.NB_STREAMLINES] = len(self.points) trk_file = BytesIO() TrkFile.save(streamlines, trk_file) trk_file.seek(0, os.SEEK_SET) - trk = nib.streamlines.load(trk_file, lazy_load=False) + trk = nib.streamlines.load(trk_file, ref=None, lazy_load=False) - hdr = trk.get_header() + hdr = trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(trk), len(self.points)) diff --git a/nibabel/streamlines/tests/test_utils.py b/nibabel/streamlines/tests/test_utils.py index a0dbcd8b61..1e6ae83b3d 100644 --- a/nibabel/streamlines/tests/test_utils.py +++ b/nibabel/streamlines/tests/test_utils.py @@ -12,7 
+12,7 @@ import nibabel.streamlines.utils as streamline_utils -from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines.base_format import Streamlines, LazyStreamlines from nibabel.streamlines.base_format import HeaderError from nibabel.streamlines.header import Field @@ -116,9 +116,9 @@ def setUp(self): def test_load_empty_file(self): for empty_filename in self.empty_filenames: - empty_streamlines = streamline_utils.load(empty_filename, lazy_load=False) + empty_streamlines = streamline_utils.load(empty_filename, None, lazy_load=False) - hdr = empty_streamlines.get_header() + hdr = empty_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], 0) assert_equal(len(empty_streamlines), 0) @@ -139,9 +139,10 @@ def test_load_empty_file(self): def test_load_simple_file(self): for simple_filename in self.simple_filenames: - simple_streamlines = streamline_utils.load(simple_filename, lazy_load=False) + simple_streamlines = streamline_utils.load(simple_filename, None, lazy_load=False) + assert_true(type(simple_streamlines), Streamlines) - hdr = simple_streamlines.get_header() + hdr = simple_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(simple_streamlines), len(self.points)) @@ -156,11 +157,12 @@ def test_load_simple_file(self): pass # Test lazy_load - simple_streamlines = streamline_utils.load(simple_filename, lazy_load=True) + simple_streamlines = streamline_utils.load(simple_filename, None, lazy_load=True) + assert_true(type(simple_streamlines), LazyStreamlines) - hdr = simple_streamlines.get_header() + hdr = simple_streamlines.header + assert_true(Field.NB_STREAMLINES in hdr) assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) - assert_equal(len(simple_streamlines), len(self.points)) points = simple_streamlines.points assert_arrays_equal(points, self.points) @@ -171,9 +173,9 @@ def test_load_simple_file(self): def test_load_complex_file(self): for complex_filename in self.complex_filenames: - complex_streamlines = streamline_utils.load(complex_filename, lazy_load=False) + complex_streamlines = streamline_utils.load(complex_filename, None, lazy_load=False) - hdr = complex_streamlines.get_header() + hdr = complex_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(complex_streamlines), len(self.points)) @@ -184,9 +186,9 @@ def test_load_complex_file(self): for point, scalar, prop in complex_streamlines: pass - complex_streamlines = streamline_utils.load(complex_filename, lazy_load=True) + complex_streamlines = streamline_utils.load(complex_filename, None, lazy_load=True) - hdr = complex_streamlines.get_header() + hdr = complex_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(complex_streamlines), len(self.points)) @@ -203,9 +205,9 @@ def test_save_simple_file(self): streamlines = Streamlines(self.points) streamline_utils.save(streamlines, f.name) - simple_streamlines = streamline_utils.load(f, lazy_load=False) + simple_streamlines = streamline_utils.load(f, None, lazy_load=False) - hdr = simple_streamlines.get_header() + hdr = simple_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(simple_streamlines), len(self.points)) @@ -219,9 +221,9 @@ def test_save_complex_file(self): streamlines = Streamlines(self.points, scalars=self.colors) streamline_utils.save(streamlines, f.name) - complex_streamlines = streamline_utils.load(f, lazy_load=False) + complex_streamlines = streamline_utils.load(f, 
None, lazy_load=False) - hdr = complex_streamlines.get_header() + hdr = complex_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(complex_streamlines), len(self.points)) @@ -237,9 +239,9 @@ def test_save_complex_file(self): streamlines = Streamlines(self.points, properties=self.mean_curvature_torsion) streamline_utils.save(streamlines, f.name) - complex_streamlines = streamline_utils.load(f, lazy_load=False) + complex_streamlines = streamline_utils.load(f, None, lazy_load=False) - hdr = complex_streamlines.get_header() + hdr = complex_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(complex_streamlines), len(self.points)) @@ -255,9 +257,9 @@ def test_save_complex_file(self): streamlines = Streamlines(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) streamline_utils.save(streamlines, f.name) - complex_streamlines = streamline_utils.load(f, lazy_load=False) + complex_streamlines = streamline_utils.load(f, None, lazy_load=False) - hdr = complex_streamlines.get_header() + hdr = complex_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(complex_streamlines), len(self.points)) @@ -276,15 +278,13 @@ def test_save_file_from_generator(self): gen_scalars = (scalar for scalar in self.colors) gen_properties = (prop for prop in self.mean_curvature_torsion) - assert_raises(HeaderError, Streamlines, points=gen_points, scalars=gen_scalars, properties=gen_properties) - - hdr = {Field.NB_STREAMLINES: len(self.points)} - streamlines = Streamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties, hdr=hdr) + streamlines = LazyStreamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties) + #streamlines.hdr[Field.NB_STREAMLINES] = len(self.points) streamline_utils.save(streamlines, f.name) - streamlines_loaded = streamline_utils.load(f, lazy_load=False) + streamlines_loaded = streamline_utils.load(f, None, lazy_load=False) - hdr = streamlines_loaded.get_header() + hdr = streamlines_loaded.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(streamlines_loaded), len(self.points)) @@ -303,15 +303,13 @@ def test_save_file_from_function(self): gen_scalars = lambda: (scalar for scalar in self.colors) gen_properties = lambda: (prop for prop in self.mean_curvature_torsion) - assert_raises(HeaderError, Streamlines, points=gen_points, scalars=gen_scalars, properties=gen_properties) - - hdr = {Field.NB_STREAMLINES: len(self.points)} - streamlines = Streamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties, hdr=hdr) + streamlines = LazyStreamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties) + #streamlines.hdr[Field.NB_STREAMLINES] = len(self.points) streamline_utils.save(streamlines, f.name) - streamlines_loaded = streamline_utils.load(f, lazy_load=False) + streamlines_loaded = streamline_utils.load(f, None, lazy_load=False) - hdr = streamlines_loaded.get_header() + hdr = streamlines_loaded.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(streamlines_loaded), len(self.points)) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index c256bcc0ec..28e27896e1 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -13,7 +13,7 @@ from nibabel.openers import Opener from nibabel.volumeutils import (native_code, swapped_code) -from nibabel.streamlines.base_format import Streamlines, StreamlinesFile 
+from nibabel.streamlines.base_format import Streamlines, LazyStreamlines, StreamlinesFile from nibabel.streamlines.header import Field from nibabel.streamlines.base_format import DataError, HeaderError @@ -144,7 +144,7 @@ def sanity_check(cls, fileobj): pass # Nothing more to do here else: warnings.warn("NiBabel only supports versions 1 and 2 (not v.{0}).".format(hdr_rec['version'])) - os.seek(start_position, os.SEEK_CUR) # Set the file position where it was. + f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. return False # Convert the first record of `hdr_rec` into a dictionnary @@ -157,7 +157,7 @@ def sanity_check(cls, fileobj): hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0].newbyteorder())) # Swap byte order if hdr['hdr_size'] != cls.HEADER_SIZE: warnings.warn("Invalid hdr_size: {0} instead of {1}".format(hdr['hdr_size'], cls.HEADER_SIZE)) - os.seek(start_position, os.SEEK_CUR) # Set the file position where it was. + f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. return False # By default, the voxel order is LPS. @@ -166,7 +166,6 @@ def sanity_check(cls, fileobj): is_consistent = False warnings.warn("Voxel order is not specified, will assume 'LPS' since it is Trackvis software's default.") - # Add more header fields implied by trk format. i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4") f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4") @@ -198,12 +197,12 @@ def sanity_check(cls, fileobj): 'the actual number of streamlines contained in this file ({1}). ' ).format(hdr[Field.NB_STREAMLINES], nb_streamlines)) - os.seek(start_position, os.SEEK_CUR) # Set the file position where it was. + f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. return is_consistent @classmethod - def load(cls, fileobj, lazy_load=True): + def load(cls, fileobj, hdr={}, lazy_load=False): ''' Loads streamlines from a file-like object. Parameters @@ -213,9 +212,11 @@ def load(cls, fileobj, lazy_load=True): pointing to TRK file (and ready to read from the beginning of the TRK header) - lazy_load : boolean + hdr : dict (optional) + + lazy_load : boolean (optional) Load streamlines in a lazy manner i.e. they will not be kept - in memory. For postprocessing speed, turn off this option. + in memory. Returns ------- @@ -223,11 +224,7 @@ def load(cls, fileobj, lazy_load=True): Returns an object containing streamlines' data and header information. See 'nibabel.Streamlines'. ''' - hdr = {} - with Opener(fileobj) as f: - hdr['pos_header'] = f.tell() - # Read header hdr_str = f.read(header_2_dtype.itemsize) hdr_rec = np.fromstring(string=hdr_str, dtype=header_2_dtype) @@ -235,12 +232,12 @@ def load(cls, fileobj, lazy_load=True): if hdr_rec['version'] == 1: hdr_rec = np.fromstring(string=hdr_str, dtype=header_1_dtype) elif hdr_rec['version'] == 2: - pass # Nothing more to do here + pass # Nothing more to do else: raise HeaderError('NiBabel only supports versions 1 and 2.') # Convert the first record of `hdr_rec` into a dictionnary - hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0])) + hdr.update(dict(zip(hdr_rec.dtype.names, hdr_rec[0]))) # Check endianness hdr[Field.ENDIAN] = native_code @@ -255,21 +252,30 @@ def load(cls, fileobj, lazy_load=True): if hdr[Field.VOXEL_ORDER] == "": hdr[Field.VOXEL_ORDER] = "LPS" - # Add more header fields implied by trk format. - hdr[Field.WORLD_ORDER] = "RAS" + # Keep the file position where the data begin. hdr['pos_data'] = f.tell() + # If 'count' field is 0, i.e. not provided, we have to loop until the EOF. 
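#             (An aside on that 'count' field: TRK stores it as a 4-byte integer
#             12 bytes before the end of the 1000-byte header, and a value of 0
#             conventionally means "not provided" rather than "no streamlines".
#             The test added in PATCH 04 below fabricates exactly such a file;
#             a sketch of the same trick, with a hypothetical filename:
#
#                 import numpy as np
#                 with open('simple.trk', 'rb') as f:
#                     raw = f.read()
#                 zero = np.array(0, dtype="int32").tostring()
#                 raw = raw[:1000 - 12] + zero + raw[1000 - 8:]  # overwrite 'count' with 0
#             )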
+ if hdr[Field.NB_STREAMLINES] == 0: + del hdr[Field.NB_STREAMLINES] + points = lambda: (x[0] for x in TrkFile._read_data(hdr, fileobj)) scalars = lambda: (x[1] for x in TrkFile._read_data(hdr, fileobj)) properties = lambda: (x[2] for x in TrkFile._read_data(hdr, fileobj)) data = lambda: TrkFile._read_data(hdr, fileobj) if lazy_load: - streamlines = Streamlines(points, scalars, properties, hdr=hdr) - streamlines.data = data - return streamlines + count = TrkFile._count(hdr, fileobj) + if Field.NB_STREAMLINES in hdr: + count = hdr[Field.NB_STREAMLINES] + + streamlines = LazyStreamlines(points, scalars, properties, data=data, count=count) + else: + streamlines = Streamlines(*zip(*data())) - return Streamlines(*zip(*data()), hdr=hdr) + # Set available header's information + streamlines.header.update(hdr) + return streamlines @staticmethod def _read_data(hdr, fileobj): @@ -278,6 +284,8 @@ def _read_data(hdr, fileobj): f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4") with Opener(fileobj) as f: + start_position = f.tell() + nb_pts_and_scalars = 3 + int(hdr[Field.NB_SCALARS_PER_POINT]) pts_and_scalars_size = nb_pts_and_scalars * f4_dtype.itemsize @@ -298,20 +306,79 @@ def _read_data(hdr, fileobj): dtype=f4_dtype, count=hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) + # Set the file position at the beginning of the data. f.seek(hdr['pos_data'], os.SEEK_SET) - for i in xrange(hdr[Field.NB_STREAMLINES]): + #for i in xrange(hdr[Field.NB_STREAMLINES]): + nb_streamlines = hdr.get(Field.NB_STREAMLINES, np.inf) + i = 0 + while i < nb_streamlines: + nb_pts_str = f.read(i4_dtype.itemsize) + + # Check if we reached EOF + if len(nb_pts_str) == 0: + break + # Read number of points of the next streamline. - nb_pts = struct.unpack(i4_dtype.str[:-1], f.read(i4_dtype.itemsize))[0] + nb_pts = struct.unpack(i4_dtype.str[:-1], nb_pts_str)[0] # Read streamline's data pts, scalars = read_pts_and_scalars(nb_pts) properties = read_properties() + + # TRK's streamlines are in 'voxelmm' space, we send them to voxel space. + pts = pts / hdr[Field.VOXEL_SIZES] + yield pts, scalars, properties + i += 1 + + # In case the 'count' field was not provided. + hdr[Field.NB_STREAMLINES] = i + + # Set the file position where it was. + f.seek(start_position, os.SEEK_CUR) + + @staticmethod + def _count(hdr, fileobj): + ''' Count streamlines from a file-like object using a TRK's header. ''' + nb_streamlines = 0 + + with Opener(fileobj) as f: + start_position = f.tell() + + i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4") + f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4") + + pts_and_scalars_size = (3 + hdr[Field.NB_SCALARS_PER_POINT]) * f4_dtype.itemsize + properties_size = hdr[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize + + # Set the file position at the beginning of the data. + f.seek(hdr['pos_data'], os.SEEK_SET) + + # Count the actual number of streamlines. + while True: + # Read number of points of the streamline + buf = f.read(i4_dtype.itemsize) + + if buf == '': + break # EOF + + nb_pts = struct.unpack(i4_dtype.str[:-1], buf)[0] + bytes_to_skip = nb_pts * pts_and_scalars_size + bytes_to_skip += properties_size + + # Seek to the next streamline in the file. + f.seek(bytes_to_skip * f4_dtype.itemsize, os.SEEK_CUR) + + nb_streamlines += 1 + + f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. + + return nb_streamlines @classmethod - def get_empty_header(cls): - ''' Return an empty TRK's header. ''' + def create_empty_header(cls): + ''' Return an empty TRK compliant header. 
        '''
        hdr = np.zeros(1, dtype=header_2_dtype)
        # Default values
@@ -338,16 +405,22 @@ def save(cls, streamlines, fileobj):
             If string, a filename; otherwise an open file-like object
             pointing to TRK file (and ready to read from the beginning
             of the TRK header data)
+
+        hdr : dict (optional)
+
+        Notes
+        -----
+        Streamlines are assumed to be in voxel space.
         '''
-        hdr = cls.get_empty_header()
+        hdr = cls.create_empty_header()

         # Override hdr's fields by those contained in `streamlines`'s header
-        for k, v in streamlines.get_header().items():
+        for k, v in streamlines.header.items():
             if k in header_2_dtype.fields.keys():
                 hdr[k] = v

         # Check which endianness to use to write data.
-        endianness = streamlines.get_header().get(Field.ENDIAN, native_code)
+        endianness = streamlines.header.get(Field.ENDIAN, native_code)
         if endianness == swapped_code:
             hdr = hdr.newbyteorder()
@@ -364,6 +437,7 @@ def save(cls, streamlines, fileobj):
         # Write header + data of streamlines
         with Opener(fileobj, mode="wb") as f:
             pos = f.tell()
+            # Write header
             f.write(hdr[0].tostring())

             for points, scalars, properties in streamlines:
@@ -374,6 +448,9 @@ def save(cls, streamlines, fileobj):
                 scalars = np.array(scalars, dtype=f4_dtype).reshape((len(points), -1))
                 properties = np.array(properties, dtype=f4_dtype)

+                # TRK's streamlines need to be in 'voxelmm' space
+                points = points * hdr[Field.VOXEL_SIZES]
+
                 data = struct.pack(i4_dtype.str[:-1], len(points))
                 data += np.concatenate((points, scalars), axis=1).tostring()
                 data += properties.tostring()
@@ -419,7 +496,7 @@ def pretty_print(streamlines):
         info : string
             Header's information relevant to the TRK format.
         '''
-        hdr = streamlines.get_header()
+        hdr = streamlines.header

         info = ""
         info += "MAGIC NUMBER: {0}".format(hdr[Field.MAGIC_NUMBER])
@@ -432,7 +509,7 @@ def pretty_print(streamlines):
         info += "nb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE])
         info += "property_name:\n {0}".format("\n".join(hdr['property_name']))
         info += "vox_to_world: {0}".format(hdr[Field.VOXEL_TO_WORLD])
-        info += "world_order: {0}".format(hdr[Field.WORLD_ORDER])
+        #info += "world_order: {0}".format(hdr[Field.WORLD_ORDER])
         info += "voxel_order: {0}".format(hdr[Field.VOXEL_ORDER])
         info += "image_orientation_patient: {0}".format(hdr['image_orientation_patient'])
         info += "pad1: {0}".format(hdr['pad1'])
diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py
index bb5515988d..0afc7bccda 100644
--- a/nibabel/streamlines/utils.py
+++ b/nibabel/streamlines/utils.py
@@ -2,6 +2,9 @@

 from ..externals.six import string_types

+from nibabel.streamlines import header
+from nibabel.streamlines.base_format import LazyStreamlines
+
 from nibabel.streamlines.trk import TrkFile
 #from nibabel.streamlines.tck import TckFile
 #from nibabel.streamlines.vtk import VtkFile
@@ -61,8 +64,8 @@ def detect_format(fileobj):
     return None


-def load(fileobj, lazy_load=True, anat=None):
-    ''' Loads streamlines from a file-like object in their native space.
+def load(fileobj, ref, lazy_load=False):
+    ''' Loads streamlines from a file-like object in voxel space.

     Parameters
     ----------
@@ -71,26 +74,32 @@ def load(fileobj, lazy_load=True, anat=None):
         pointing to a streamlines file (and ready to read from the beginning
         of the streamlines file's header)

+    ref : filename | `Nifti1Image` object | 2D array (3,3) | 2D array (4,4) | None
+        Reference space where streamlines have been created.
+
+    lazy_load : boolean (optional)
+        Load streamlines in a lazy manner i.e. they will not be kept
+        in memory.
+ Returns ------- obj : instance of ``Streamlines`` Returns an instance of a ``Streamlines`` class containing data and metadata of streamlines loaded from ``fileobj``. ''' - - # TODO: Ask everyone what should be the behavior if the anat is provided. - # if anat is None: - # warnings.warn("WARNING: Streamlines will be loaded in their native space (i.e. as they were saved).") - streamlines_file = detect_format(fileobj) if streamlines_file is None: raise TypeError("Unknown format for 'fileobj': {0}!".format(fileobj)) - return streamlines_file.load(fileobj, lazy_load=lazy_load) + hdr = {} + if ref is not None: + hdr = header.create_header_from_reference(ref) + return streamlines_file.load(fileobj, hdr=hdr, lazy_load=lazy_load) -def save(streamlines, filename): + +def save(streamlines, filename, ref=None): ''' Saves a ``Streamlines`` object to a file Parameters @@ -100,13 +109,21 @@ def save(streamlines, filename): filename : string Name of the file where the streamlines will be saved. The format will be guessed from ``filename``. - ''' + ref : filename | `Nifti1Image` object | 2D array (3,3) | 2D array (4,4) | None (optional) + Reference space the streamlines belong to. Default: get ref from `streamlines.header`. + ''' streamlines_file = detect_format(filename) if streamlines_file is None: raise TypeError("Unknown format for 'filename': {0}!".format(filename)) + if ref is not None: + # Create a `LazyStreamlines` from `streamlines` but using the new reference image. + streamlines = LazyStreamlines(data=iter(streamlines)) + streamlines.header.update(streamlines.header) + streamlines.header.update(header.create_header_from_reference(ref)) + streamlines_file.save(streamlines, filename) From 0f7b2a693d95403e644a82d5ad075907f4498e5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 4 Mar 2015 14:50:28 -0500 Subject: [PATCH 04/54] Fixed bug in TRK count function --- nibabel/streamlines/tests/test_trk.py | 11 +++++++++++ nibabel/streamlines/trk.py | 10 +++++----- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 04dea5a88e..cb32cb6334 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -134,6 +134,17 @@ def test_load_complex_file(self): for point, scalar, prop in complex_trk: pass + def test_load_file_with_no_count(self): + trk_file = open(self.simple_trk_filename, 'rb').read() + # Simulate a TRK file where count was not provided. + count = np.array(0, dtype="int32").tostring() + trk_file = trk_file[:1000-12] + count + trk_file[1000-8:] + streamlines = nib.streamlines.load(BytesIO(trk_file), ref=None, lazy_load=False) + assert_equal(len(streamlines), len(self.points)) + + streamlines = nib.streamlines.load(BytesIO(trk_file), ref=None, lazy_load=True) + assert_equal(len(streamlines), len(self.points)) + def test_write_simple_file(self): streamlines = Streamlines(self.points) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 28e27896e1..dbeda618ae 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -178,7 +178,7 @@ def sanity_check(cls, fileobj): # Read number of points of the streamline buf = f.read(i4_dtype.itemsize) - if buf == '': + if len(buf) == 0: break # EOF nb_pts = struct.unpack(i4_dtype.str[:-1], buf)[0] @@ -187,7 +187,7 @@ def sanity_check(cls, fileobj): bytes_to_skip += properties_size # Seek to the next streamline in the file. 
- f.seek(bytes_to_skip * f4_dtype.itemsize, os.SEEK_CUR) + f.seek(bytes_to_skip, os.SEEK_CUR) nb_streamlines += 1 @@ -265,7 +265,7 @@ def load(cls, fileobj, hdr={}, lazy_load=False): data = lambda: TrkFile._read_data(hdr, fileobj) if lazy_load: - count = TrkFile._count(hdr, fileobj) + count = lambda: TrkFile._count(hdr, fileobj) if Field.NB_STREAMLINES in hdr: count = hdr[Field.NB_STREAMLINES] @@ -360,7 +360,7 @@ def _count(hdr, fileobj): # Read number of points of the streamline buf = f.read(i4_dtype.itemsize) - if buf == '': + if len(buf) == 0: break # EOF nb_pts = struct.unpack(i4_dtype.str[:-1], buf)[0] @@ -368,7 +368,7 @@ def _count(hdr, fileobj): bytes_to_skip += properties_size # Seek to the next streamline in the file. - f.seek(bytes_to_skip * f4_dtype.itemsize, os.SEEK_CUR) + f.seek(bytes_to_skip, os.SEEK_CUR) nb_streamlines += 1 From 66d5cd1af62a62a494ff58e1f56a69b4de978cd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 4 Mar 2015 23:29:03 -0500 Subject: [PATCH 05/54] Removed unused function check_integrity --- nibabel/streamlines/trk.py | 90 +------------------------------------- 1 file changed, 1 insertion(+), 89 deletions(-) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index dbeda618ae..134beaa9d6 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -112,95 +112,6 @@ def is_correct_format(cls, fileobj): return False - @classmethod - def sanity_check(cls, fileobj): - ''' Check if data is consistent with information contained in the header. - [Might be useful] - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to TRK file (and ready to read from the beginning - of the TRK header data) - - Returns - ------- - is_consistent : boolean - Returns True if data is consistent with header, False otherwise. - ''' - is_consistent = True - - with Opener(fileobj) as f: - start_position = f.tell() - - # Read header - hdr_str = f.read(header_2_dtype.itemsize) - hdr_rec = np.fromstring(string=hdr_str, dtype=header_2_dtype) - - if hdr_rec['version'] == 1: - hdr_rec = np.fromstring(string=hdr_str, dtype=header_1_dtype) - elif hdr_rec['version'] == 2: - pass # Nothing more to do here - else: - warnings.warn("NiBabel only supports versions 1 and 2 (not v.{0}).".format(hdr_rec['version'])) - f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. - return False - - # Convert the first record of `hdr_rec` into a dictionnary - hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0])) - - # Check endianness - hdr[Field.ENDIAN] = native_code - if hdr['hdr_size'] != cls.HEADER_SIZE: - hdr[Field.ENDIAN] = swapped_code - hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0].newbyteorder())) # Swap byte order - if hdr['hdr_size'] != cls.HEADER_SIZE: - warnings.warn("Invalid hdr_size: {0} instead of {1}".format(hdr['hdr_size'], cls.HEADER_SIZE)) - f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. - return False - - # By default, the voxel order is LPS. 
- # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates - if hdr[Field.VOXEL_ORDER] == "": - is_consistent = False - warnings.warn("Voxel order is not specified, will assume 'LPS' since it is Trackvis software's default.") - - i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4") - f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4") - - pts_and_scalars_size = (3 + hdr[Field.NB_SCALARS_PER_POINT]) * f4_dtype.itemsize - properties_size = hdr[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize - - #Verify the number of streamlines specified in the header is correct. - nb_streamlines = 0 - while True: - # Read number of points of the streamline - buf = f.read(i4_dtype.itemsize) - - if len(buf) == 0: - break # EOF - - nb_pts = struct.unpack(i4_dtype.str[:-1], buf)[0] - - bytes_to_skip = nb_pts * pts_and_scalars_size - bytes_to_skip += properties_size - - # Seek to the next streamline in the file. - f.seek(bytes_to_skip, os.SEEK_CUR) - - nb_streamlines += 1 - - if hdr[Field.NB_STREAMLINES] != nb_streamlines: - is_consistent = False - warnings.warn(('The number of streamlines specified in header ({1}) does not match ' - 'the actual number of streamlines contained in this file ({1}). ' - ).format(hdr[Field.NB_STREAMLINES], nb_streamlines)) - - f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. - - return is_consistent - @classmethod def load(cls, fileobj, hdr={}, lazy_load=False): ''' Loads streamlines from a file-like object. @@ -250,6 +161,7 @@ def load(cls, fileobj, hdr={}, lazy_load=False): # By default, the voxel order is LPS. # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates if hdr[Field.VOXEL_ORDER] == "": + warnings.warn("Voxel order is not specified, will assume 'LPS' since it is Trackvis software's default.") hdr[Field.VOXEL_ORDER] = "LPS" # Keep the file position where the data begin. 
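The seek fix above (PATCH 04) and the `sanity_check` removal (PATCH 05) both touch the same record-walking loop, so it is worth spelling out once. Each TRK record is an int32 point count followed by `nb_pts * (3 + nb_scalars)` float32 values and `nb_properties` float32 values; `bytes_to_skip` is therefore already a byte count, and the extra `* f4_dtype.itemsize` removed in PATCH 04 over-skipped by a factor of four. A self-contained sketch of the corrected counting loop (illustrative only, assuming a little-endian file; the library version lives in `TrkFile._count`):

    import os
    import struct

    def count_trk_streamlines(f, offset_data, nb_scalars, nb_properties):
        ''' Count streamlines by hopping over records without decoding them. '''
        f.seek(offset_data, os.SEEK_SET)
        pts_and_scalars_size = (3 + nb_scalars) * 4   # float32s per point (+ scalars)
        properties_size = nb_properties * 4           # float32s per streamline
        nb_streamlines = 0
        while True:
            buf = f.read(4)          # int32: number of points in this record
            if len(buf) == 0:        # EOF check that works for bytes on Python 3,
                break                # where the old `buf == ''` never matched
            nb_pts = struct.unpack('<i', buf)[0]
            f.seek(nb_pts * pts_and_scalars_size + properties_size, os.SEEK_CUR)
            nb_streamlines += 1
        return nb_streamlines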
From 656a664cfcf59daec8d4952aa9440d8c37fcf133 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Thu, 16 Apr 2015 23:10:09 -0400 Subject: [PATCH 06/54] Refactored code and tests --- nibabel/streamlines/__init__.py | 140 ++++- nibabel/streamlines/base_format.py | 276 +++++---- nibabel/streamlines/header.py | 163 +++--- nibabel/streamlines/tests/data/complex.trk | Bin 1228 -> 1228 bytes nibabel/streamlines/tests/data/empty.trk | Bin 1000 -> 1000 bytes nibabel/streamlines/tests/data/simple.trk | Bin 1108 -> 1108 bytes nibabel/streamlines/tests/test_base_format.py | 377 +++++++------ nibabel/streamlines/tests/test_header.py | 37 ++ nibabel/streamlines/tests/test_streamlines.py | 229 ++++++++ nibabel/streamlines/tests/test_trk.py | 325 ++++------- nibabel/streamlines/tests/test_utils.py | 332 +---------- nibabel/streamlines/trk.py | 527 ++++++++++-------- nibabel/streamlines/utils.py | 161 +----- nibabel/testing/__init__.py | 15 +- 14 files changed, 1349 insertions(+), 1233 deletions(-) create mode 100644 nibabel/streamlines/tests/test_header.py create mode 100644 nibabel/streamlines/tests/test_streamlines.py diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 7f00864c5f..7f3d7dcdf2 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -1,9 +1,137 @@ +from .header import Field +from .base_format import Streamlines, LazyStreamlines -from nibabel.streamlines.utils import load, save +from nibabel.streamlines.trk import TrkFile +#from nibabel.streamlines.tck import TckFile +#from nibabel.streamlines.vtk import VtkFile -from nibabel.streamlines.base_format import Streamlines, LazyStreamlines -from nibabel.streamlines.header import Field +# List of all supported formats +FORMATS = {".trk": TrkFile, + #".tck": TckFile, + #".vtk": VtkFile, + } -from nibabel.streamlines.trk import TrkFile -#from nibabel.streamlines.trk import TckFile -#from nibabel.streamlines.trk import VtkFile + +def is_supported(fileobj): + ''' Checks if the file-like object if supported by NiBabel. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object pointing + to a streamlines file (and ready to read from the beginning of the + header) + + Returns + ------- + is_supported : boolean + ''' + return detect_format(fileobj) is not None + + +def detect_format(fileobj): + ''' Returns the StreamlinesFile object guessed from the file-like object. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object pointing + to a streamlines file (and ready to read from the beginning of the + header) + + Returns + ------- + streamlines_file : StreamlinesFile object + Object that can be used to manage a streamlines file. + See 'nibabel.streamlines.StreamlinesFile'. + ''' + for format in FORMATS.values(): + try: + if format.is_correct_format(fileobj): + return format + + except IOError: + pass + + import os + from ..externals.six import string_types + if isinstance(fileobj, string_types): + _, ext = os.path.splitext(fileobj) + return FORMATS.get(ext, None) + + return None + + +def load(fileobj, ref, lazy_load=False): + ''' Loads streamlines from a file-like object in voxel space. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the beginning + of the streamlines file's header). 
+ + ref : filename | `Nifti1Image` object | 2D array (4,4) + Reference space where streamlines will live in `fileobj`. + + lazy_load : boolean (optional) + Load streamlines in a lazy manner i.e. they will not be kept + in memory. + + Returns + ------- + obj : instance of `Streamlines` + Returns an instance of a `Streamlines` class containing data and metadata + of streamlines loaded from `fileobj`. + ''' + streamlines_file = detect_format(fileobj) + + if streamlines_file is None: + raise TypeError("Unknown format for 'fileobj': {0}!".format(fileobj)) + + return streamlines_file.load(fileobj, ref, lazy_load=lazy_load) + + +def save(streamlines, filename, ref=None): + ''' Saves a `Streamlines` object to a file + + Parameters + ---------- + streamlines : `Streamlines` object + Streamlines to be saved. + + filename : str + Name of the file where the streamlines will be saved. The format will + be guessed from `filename`. + + ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) + Reference space where streamlines will live in `fileobj`. + ''' + streamlines_file = detect_format(filename) + + if streamlines_file is None: + raise TypeError("Unknown streamlines file format: '{0}'!".format(filename)) + + streamlines_file.save(streamlines, filename, ref) + + +def convert(in_fileobj, out_filename, ref): + ''' Converts a streamlines file to another format. + + Parameters + ---------- + in_fileobj : string or file-like object + If string, a filename; otherwise an open file-like object pointing + to a streamlines file (and ready to read from the beginning of the + header). + + out_filename : str + Name of the file where the streamlines will be saved. The format will + be guessed from `out_filename`. + + ref : filename | `Nifti1Image` object | 2D array (4,4) + Reference space where streamlines live in `fileobj`. + ''' + streamlines = load(in_fileobj, ref, lazy_load=True) + save(streamlines, out_filename) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 30baf4001e..a94f5728ad 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -1,9 +1,18 @@ -import numpy as np from warnings import warn -from nibabel.streamlines.header import Field +from nibabel.externals.six.moves import zip_longest +from nibabel.affines import apply_affine -from ..externals.six.moves import zip_longest +from .header import StreamlinesHeader +from .utils import pop + + +class UsageWarning(Warning): + pass + + +class HeaderWarning(Warning): + pass class HeaderError(Exception): @@ -37,20 +46,12 @@ class Streamlines(object): Sequence of T ndarrays of shape (P,) where T is the number of streamlines defined by ``points``, P is the number of properties associated to each streamline. - - hdr : dict - Header containing meta information about the streamlines. For a list - of common header's fields to use as keys see `nibabel.streamlines.Field`. ''' - def __init__(self, points=[], scalars=[], properties=[]): #, hdr={}): - # Create basic header from given informations. 
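#     (Aside: a usage sketch for the new module-level API above; filenames are
#     hypothetical and only the TRK format is registered in FORMATS at this
#     point in the series:
#
#         import nibabel as nib
#
#         # One step: load lazily, save in the format guessed from the extension.
#         nib.streamlines.convert('bundle.trk', 'bundle_copy.trk', 'anat.nii.gz')
#
#         # ...which is shorthand for:
#         streamlines = nib.streamlines.load('bundle.trk', ref='anat.nii.gz', lazy_load=True)
#         nib.streamlines.save(streamlines, 'bundle_copy.trk')
#     )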
- self._header = {} - self._header[Field.VOXEL_TO_WORLD] = np.eye(4) - - self.points = points - self.scalars = scalars - self.properties = properties - + def __init__(self, points=None, scalars=None, properties=None): + self._header = StreamlinesHeader() + self.points = points + self.scalars = scalars + self.properties = properties @property def header(self): @@ -62,8 +63,8 @@ def points(self): @points.setter def points(self, value): - self._points = value - self._header[Field.NB_STREAMLINES] = len(self.points) + self._points = value if value else [] + self.header.nb_streamlines = len(self.points) @property def scalars(self): @@ -71,10 +72,11 @@ def scalars(self): @scalars.setter def scalars(self, value): - self._scalars = value - self._header[Field.NB_SCALARS_PER_POINT] = 0 - if len(self.scalars) > 0: - self._header[Field.NB_SCALARS_PER_POINT] = len(self.scalars[0]) + self._scalars = value if value else [] + self.header.nb_scalars_per_point = 0 + + if len(self.scalars) > 0 and len(self.scalars[0]) > 0: + self.header.nb_scalars_per_point = len(self.scalars[0][0]) @property def properties(self): @@ -82,10 +84,11 @@ def properties(self): @properties.setter def properties(self, value): - self._properties = value - self._header[Field.NB_PROPERTIES_PER_STREAMLINE] = 0 + self._properties = value if value else [] + self.header.nb_properties_per_streamline = 0 + if len(self.properties) > 0: - self._header[Field.NB_PROPERTIES_PER_STREAMLINE] = len(self.properties[0]) + self.header.nb_properties_per_streamline = len(self.properties[0]) def __iter__(self): return zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) @@ -100,49 +103,89 @@ def __getitem__(self, idx): if len(self.properties) > 0: properties = self.properties[idx] + if type(idx) is slice: + return list(zip_longest(pts, scalars, properties, fillvalue=[])) + return pts, scalars, properties def __len__(self): return len(self.points) + def to_world_space(self, as_generator=False): + affine = self.header.voxel_to_world + new_points = (apply_affine(affine, pts) for pts in self.points) + + if not as_generator: + return list(new_points) + + return new_points + class LazyStreamlines(Streamlines): ''' Class containing information about streamlines. - Streamlines objects have three main properties: ``points``, ``scalars`` - and ``properties``. Streamlines objects can be iterate over producing - tuple of ``points``, ``scalars`` and ``properties`` for each streamline. + Streamlines objects have four main properties: ``header``, ``points``, + ``scalars`` and ``properties``. Streamlines objects are iterable and + produce tuple of ``points``, ``scalars`` and ``properties`` for each + streamline. Parameters ---------- - points : sequence of ndarray of shape (N, 3) - Sequence of T streamlines. One streamline is an ndarray of shape (N, 3) - where N is the number of points in a streamline. - - scalars : sequence of ndarray of shape (N, M) - Sequence of T ndarrays of shape (N, M) where T is the number of - streamlines defined by ``points``, N is the number of points - for a particular streamline and M is the number of scalars - associated to each point (excluding the three coordinates). - - properties : sequence of ndarray of shape (P,) - Sequence of T ndarrays of shape (P,) where T is the number of - streamlines defined by ``points``, P is the number of properties + points_func : coroutine ouputting (N,3) array-like (optional) + Function yielding streamlines' points. 
One streamline's points is + an array-like of shape (N,3) where N is the number of points in a + streamline. + + scalars_func : coroutine ouputting (N,M) array-like (optional) + Function yielding streamlines' scalars. One streamline's scalars is + an array-like of shape (N,M) where N is the number of points for a + particular streamline and M is the number of scalars associated to + each point (excluding the three coordinates). + + properties_func : coroutine ouputting (P,) array-like (optional) + Function yielding streamlines' properties. One streamline's properties + is an array-like of shape (P,) where P is the number of properties associated to each streamline. - hdr : dict - Header containing meta information about the streamlines. For a list - of common header's fields to use as keys see `nibabel.streamlines.Field`. - ''' - def __init__(self, points=[], scalars=[], properties=[], data=None, count=None, getitem=None): #, hdr={}): - super(LazyStreamlines, self).__init__(points, scalars, properties) + getitem_func : function `idx -> 3-tuples` (optional) + Function returning streamlines (one or a list of 3-tuples) given + an index or a slice (i.e. the __getitem__ function to use). + Notes + ----- + If provided, ``scalars`` and ``properties`` must yield the same number of + values as ``points``. + ''' + def __init__(self, points_func=lambda:[], scalars_func=lambda: [], properties_func=lambda: [], getitem_func=None): + super(LazyStreamlines, self).__init__(points_func, scalars_func, properties_func) self._data = lambda: zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) - if data is not None: - self._data = data if callable(data) else lambda: data + self._getitem = getitem_func - self._count = count - self._getitem = getitem + @classmethod + def create_from_data(cls, data_func): + ''' Saves streamlines to a file-like object. + + Parameters + ---------- + data_func : coroutine ouputting tuple (optional) + Function yielding 3-tuples, (streamline's points, streamline's + scalars, streamline's properties). A streamline's points is an + array-like of shape (N,3), a streamline's scalars is an array-like + of shape (N,M) and streamline's properties is an array-like of + shape (P,) where N is the number of points for a particular + streamline, M is the number of scalars associated to each point + (excluding the three coordinates) and P is the number of properties + associated to each streamline. 
+ ''' + if not callable(data_func): + raise TypeError("`data` must be a coroutine.") + + lazy_streamlines = cls() + lazy_streamlines._data = data_func + lazy_streamlines.points = lambda: (x[0] for x in data_func()) + lazy_streamlines.scalars = lambda: (x[1] for x in data_func()) + lazy_streamlines.properties = lambda: (x[2] for x in data_func()) + return lazy_streamlines @property def points(self): @@ -150,7 +193,10 @@ def points(self): @points.setter def points(self, value): - self._points = value if callable(value) else lambda: value + if not callable(value): + raise TypeError("`points` must be a coroutine.") + + self._points = value @property def scalars(self): @@ -158,7 +204,14 @@ def scalars(self): @scalars.setter def scalars(self, value): - self._scalars = value if callable(value) else lambda: value + if not callable(value): + raise TypeError("`scalars` must be a coroutine.") + + self._scalars = value + self.header.nb_scalars_per_point = 0 + scalars = pop(self.scalars) + if scalars is not None and len(scalars) > 0: + self.header.nb_scalars_per_point = len(scalars[0]) @property def properties(self): @@ -166,7 +219,14 @@ def properties(self): @properties.setter def properties(self, value): - self._properties = value if callable(value) else lambda: value + if not callable(value): + raise TypeError("`properties` must be a coroutine.") + + self._properties = value + self.header.nb_properties_per_streamline = 0 + properties = pop(self.properties) + if properties is not None: + self.header.nb_properties_per_streamline = len(properties) def __getitem__(self, idx): if self._getitem is None: @@ -175,37 +235,28 @@ def __getitem__(self, idx): return self._getitem(idx) def __iter__(self): - return self._data() + i = 0 + for i, s in enumerate(self._data(), start=1): + yield s + + # To be safe, update information about number of streamlines. + self.header.nb_streamlines = i def __len__(self): - # If length is unknown, we'll try to get it as rapidely and accurately as possible. - if self._count is None: - # Length might be contained in the header. - if Field.NB_STREAMLINES in self.header: - return self.header[Field.NB_STREAMLINES] - - if callable(self._count): - # Length might be obtained by re-parsing the file (if streamlines come from one). - self._count = self._count() - - if self._count is None: - try: - # Will work if `points` is a finite sequence (e.g. list, ndarray) - self._count = len(self.points) - except: - pass - - if self._count is None: - # As a last resort, count them by iterating through the list of points (while keeping a copy). + # If length is unknown, we obtain it by iterating through streamlines. + if self.header.nb_streamlines is None: warn("Number of streamlines will be determined manually by looping" - " through the streamlines. Note this will consume any" - " generator used to create this `Streamlines`object. If you" - " know the actual number of streamlines, you might want to" - " set `Field.NB_STREAMLINES` of `self.header` beforehand.") - + " through the streamlines. If you know the actual number of" + " streamlines, you might want to set it beforehand via" + " `self.header.nb_streamlines`." 
+ " Note this will consume any generators used to create this" + " `LazyStreamlines` object.", UsageWarning) return sum(1 for _ in self) - return self._count + return self.header.nb_streamlines + + def to_world_space(self): + return super(LazyStreamlines, self).to_world_space(as_generator=True) class StreamlinesFile: @@ -213,19 +264,29 @@ class StreamlinesFile: @classmethod def get_magic_number(cls): - ''' Return streamlines file's magic number. ''' + ''' Returns streamlines file's magic number. ''' + raise NotImplementedError() + + @classmethod + def can_save_scalars(cls): + ''' Tells if the streamlines format supports saving scalars. ''' + raise NotImplementedError() + + @classmethod + def can_save_properties(cls): + ''' Tells if the streamlines format supports saving properties. ''' raise NotImplementedError() @classmethod def is_correct_format(cls, fileobj): - ''' Check if the file has the right streamlines file format. + ''' Checks if the file has the right streamlines file format. Parameters ---------- fileobj : string or file-like object If string, a filename; otherwise an open file-like object pointing to a streamlines file (and ready to read from the - beginning of the header) + beginning of the header). Returns ------- @@ -234,13 +295,8 @@ def is_correct_format(cls, fileobj): ''' raise NotImplementedError() - @classmethod - def get_empty_header(cls): - ''' Return an empty streamlines file's header. ''' - raise NotImplementedError() - - @classmethod - def load(cls, fileobj, lazy_load=True): + @staticmethod + def load(fileobj, ref, lazy_load=True): ''' Loads streamlines from a file-like object. Parameters @@ -248,7 +304,10 @@ def load(cls, fileobj, lazy_load=True): fileobj : string or file-like object If string, a filename; otherwise an open file-like object pointing to a streamlines file (and ready to read from the - beginning of the header) + beginning of the header). + + ref : filename | `Nifti1Image` object | 2D array (4,4) + Reference space where streamlines live in `fileobj`. lazy_load : boolean Load streamlines in a lazy manner i.e. they will not be kept @@ -262,8 +321,8 @@ def load(cls, fileobj, lazy_load=True): ''' raise NotImplementedError() - @classmethod - def save(cls, streamlines, fileobj): + @staticmethod + def save(streamlines, fileobj, ref=None): ''' Saves streamlines to a file-like object. Parameters @@ -275,35 +334,38 @@ def save(cls, streamlines, fileobj): fileobj : string or file-like object If string, a filename; otherwise an open file-like object opened and ready to write. + + ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) + Reference space where streamlines will live in `fileobj`. ''' raise NotImplementedError() @staticmethod def pretty_print(streamlines): - ''' Gets a formatted string contaning header's information - relevant to the streamlines file format. + ''' Gets a formatted string of the header of a streamlines file format. Parameters ---------- - streamlines : Streamlines object - Object containing streamlines' data and header information. - See 'nibabel.Streamlines'. + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header). Returns ------- info : string - Header's information relevant to the streamlines file format. + Header information relevant to the streamlines file format. 
''' raise NotImplementedError() -class DynamicStreamlineFile(StreamlinesFile): - ''' Convenience class to encapsulate streamlines file format - that supports appending streamlines to an existing file. - ''' +# class DynamicStreamlineFile(StreamlinesFile): +# ''' Convenience class to encapsulate streamlines file format +# that supports appending streamlines to an existing file. +# ''' - def append(self, streamlines): - raise NotImplementedError() +# def append(self, streamlines): +# raise NotImplementedError() - def __iadd__(self, streamlines): - return self.append(streamlines) +# def __iadd__(self, streamlines): +# return self.append(streamlines) diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py index 357930c012..46e4af3a94 100644 --- a/nibabel/streamlines/header.py +++ b/nibabel/streamlines/header.py @@ -1,7 +1,6 @@ import numpy as np -import nibabel from nibabel.orientations import aff2axcodes -from nibabel.spatialimages import SpatialImage +from collections import OrderedDict class Field: @@ -21,80 +20,92 @@ class Field: ORIGIN = "origin" VOXEL_TO_WORLD = "voxel_to_world" VOXEL_ORDER = "voxel_order" - #WORLD_ORDER = "world_order" ENDIAN = "endian" -def create_header_from_reference(ref): - if type(ref) is np.ndarray: - return create_header_from_affine(ref) - elif isinstance(ref) is SpatialImage: - return create_header_from_nifti(ref) - - # Assume `ref` is a filename: - img = nibabel.load(ref) - return create_header_from_nifti(img) - - -def create_header_from_nifti(img): - ''' Creates a common streamlines' header using a spatial image. - - Based on the information of the nifti image a dictionnary is created - containing the following keys: `Field.ORIGIN`, `Field.DIMENSIONS`, - `Field.VOXEL_SIZES`, `Field.VOXEL_TO_WORLD`, `Field.WORLD_ORDER` - and `Field.VOXEL_ORDER`. - - Parameters - ---------- - img : `SpatialImage` object - Image containing information about the anatomy where streamlines - were created. - - Returns - ------- - hdr : dict - Header containing meta information about streamlines extracted - from the anatomy. - ''' - img_header = img.get_header() - affine = img_header.get_best_affine() - - hdr = {} - - hdr[Field.ORIGIN] = affine[:3, -1] - hdr[Field.DIMENSIONS] = img_header.get_data_shape()[:3] - hdr[Field.VOXEL_SIZES] = img_header.get_zooms()[:3] - hdr[Field.VOXEL_TO_WORLD] = affine - #hdr[Field.WORLD_ORDER] = "RAS" # Nifti space - hdr[Field.VOXEL_ORDER] = "".join(aff2axcodes(affine)) - - return hdr - - -def create_header_from_affine(affine): - ''' Creates a common streamlines' header using an affine matrix. - - Based on the information of the affine matrix a dictionnary is created - containing the following keys: `Field.ORIGIN`, `Field.DIMENSIONS`, - `Field.VOXEL_SIZES`, `Field.VOXEL_TO_WORLD`, `Field.WORLD_ORDER` - and `Field.VOXEL_ORDER`. - - Parameters - ---------- - affine : 2D array (3,3) | 2D array (4,4) - Affine matrix that transforms streamlines from voxel space to world-space. - - Returns - ------- - hdr : dict - Header containing meta information about streamlines. 
- ''' - hdr = {} - - hdr[Field.ORIGIN] = affine[:3, -1] - hdr[Field.VOXEL_SIZES] = np.sqrt(np.sum(affine[:3, :3]**2, axis=0)) - hdr[Field.VOXEL_TO_WORLD] = affine - #hdr[Field.WORLD_ORDER] = "RAS" # Nifti space - hdr[Field.VOXEL_ORDER] = "".join(aff2axcodes(affine)) - - return hdr +class StreamlinesHeader(object): + def __init__(self): + self._nb_streamlines = None + self._nb_scalars_per_point = None + self._nb_properties_per_streamline = None + self._voxel_to_world = np.eye(4) + self.extra = OrderedDict() + + @property + def voxel_to_world(self): + return self._voxel_to_world + + @voxel_to_world.setter + def voxel_to_world(self, value): + self._voxel_to_world = np.array(value, dtype=np.float32) + + @property + def voxel_sizes(self): + """ Get voxel sizes from voxel_to_world. """ + return np.sqrt(np.sum(self.voxel_to_world[:3, :3]**2, axis=0)) + + @voxel_sizes.setter + def voxel_sizes(self, value): + scaling = np.r_[np.array(value), [1]] + old_scaling = np.r_[np.array(self.voxel_sizes), [1]] + # Remove old scaling and apply new one + self.voxel_to_world = np.dot(np.diag(scaling/old_scaling), self.voxel_to_world) + + @property + def voxel_order(self): + """ Get voxel order from voxel_to_world. """ + return "".join(aff2axcodes(self.voxel_to_world)) + + @property + def nb_streamlines(self): + return self._nb_streamlines + + @nb_streamlines.setter + def nb_streamlines(self, value): + self._nb_streamlines = int(value) + + @property + def nb_scalars_per_point(self): + return self._nb_scalars_per_point + + @nb_scalars_per_point.setter + def nb_scalars_per_point(self, value): + self._nb_scalars_per_point = int(value) + + @property + def nb_properties_per_streamline(self): + return self._nb_properties_per_streamline + + @nb_properties_per_streamline.setter + def nb_properties_per_streamline(self, value): + self._nb_properties_per_streamline = int(value) + + @property + def extra(self): + return self._extra + + @extra.setter + def extra(self, value): + self._extra = OrderedDict(value) + + def __eq__(self, other): + return (np.allclose(self.voxel_to_world, other.voxel_to_world) and + self.nb_streamlines == other.nb_streamlines and + self.nb_scalars_per_point == other.nb_scalars_per_point and + self.nb_properties_per_streamline == other.nb_properties_per_streamline and + repr(self.extra) == repr(other.extra)) # Not the robust way, but will do! 
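#     (Aside: the voxel_sizes setter above rescales voxel_to_world by dividing
#     out the old scale and multiplying in the new one with a diagonal matrix.
#     A quick check of that identity, exact for an axis-aligned affine:
#
#         import numpy as np
#
#         vox2world = np.diag([2., 2., 2., 1.])    # 2 mm isotropic voxels
#         new = np.r_[np.array([1., 1., 1.]), [1]]
#         old = np.r_[np.sqrt(np.sum(vox2world[:3, :3]**2, axis=0)), [1]]
#         rescaled = np.dot(np.diag(new / old), vox2world)
#         assert np.allclose(np.sqrt(np.sum(rescaled[:3, :3]**2, axis=0)), 1.0)
#     )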
+ + def __repr__(self): + txt = "Header{\n" + txt += "nb_streamlines: " + repr(self.nb_streamlines) + '\n' + txt += "nb_scalars_per_point: " + repr(self.nb_scalars_per_point) + '\n' + txt += "nb_properties_per_streamline: " + repr(self.nb_properties_per_streamline) + '\n' + txt += "voxel_to_world: " + repr(self.voxel_to_world) + '\n' + txt += "voxel_sizes: " + repr(self.voxel_sizes) + '\n' + + txt += "Extra fields: {\n" + for key in sorted(self.extra.keys()): + txt += " " + repr(key) + ": " + repr(self.extra[key]) + "\n" + + txt += " }\n" + return txt + "}" diff --git a/nibabel/streamlines/tests/data/complex.trk b/nibabel/streamlines/tests/data/complex.trk index 2a96a7bd35f6ce69f05907a351466ac838de1045..9bfbe5ea60917d6f98920b3a08e460be5fb6732d 100644 GIT binary patch literal 1228 zcmWFua&-1)U<5-3h6Z~CW`F}`IBTdgn1;FsklYW7D|A4K7`j=b{J|OmL5{&hISrI= zK`H(PkiiJi0ZcOofLOr+lFuQ6)gpcN-!$gfLvsg`8jw6JO(g*JOaS5yKnwy9aa%5S zhiVZ%2Udt6nqHVXbiFX~3l0nnAAp#{5hMZuAaf;vd<`JBKoj=>@*{va14Z1H%N^(j GKL-HhO(bCe literal 1228 zcmWFua&-1)U<5-3h6Z~CW`F}`IBTdgn1;FsklYW7D|A4K7`j=b{GlEKpnMBT@h^Z3 zM!09dX~Y4N&mn@(Dj03LHZ8>)ja@W21g)<0+6@>kgov52590AKz;xaC!mPia=8QD;O77UQAr@i diff --git a/nibabel/streamlines/tests/data/empty.trk b/nibabel/streamlines/tests/data/empty.trk index 023b3c5905f222fbd76ce9f82ca72be319b427f6..e78e28403b087f627c298afc88c200dc2d50bcff 100644 GIT binary patch delta 18 acmaFC{(^nO7G~xk$KZ*Inv)+ea{vHEWCn-; delta 14 WcmaFC{(^nO7UqctI+GtTa{vG^e+97s diff --git a/nibabel/streamlines/tests/data/simple.trk b/nibabel/streamlines/tests/data/simple.trk index dc5eff4adc9cd919e7bfe31a37485ecc51348109..df601e29a7cb54b7be756c8a58aa205bc19af70a 100644 GIT binary patch literal 1108 zcmWFua&-1)U<5-3h6Z~CW*7y7Is`y*g$^hYLpN)bKUhN`$T65Gr!fOnF#+)lAcGO2 z1DIwG0I`AtNE`^@dR, P->A, I->S). IF (0 based) value [3, 3] from +# coordinates (axes L->R, P->A, I->S). If (0 based) value [3, 3] from # this matrix is 0, this means the matrix is not recorded. header_2_dtd = [(Field.MAGIC_NUMBER, 'S6'), (Field.DIMENSIONS, 'h', 3), @@ -78,132 +80,78 @@ header_2_dtype = np.dtype(header_2_dtd) -class TrkFile(StreamlinesFile): - ''' Convenience class to encapsulate TRK format. ''' - - MAGIC_NUMBER = b"TRACK" - HEADER_SIZE = 1000 - - @classmethod - def get_magic_number(cls): - ''' Return TRK's magic number. ''' - return cls.MAGIC_NUMBER +class TrkReader(object): + ''' Convenience class to encapsulate TRK file format. - @classmethod - def is_correct_format(cls, fileobj): - ''' Check if the file is in TRK format. - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to TRK file (and ready to read from the beginning - of the TRK header data) - - Returns - ------- - is_correct_format : boolean - Returns True if `fileobj` is in TRK format. - ''' - with Opener(fileobj) as f: - magic_number = f.read(5) - f.seek(-5, os.SEEK_CUR) - return magic_number == cls.MAGIC_NUMBER - - return False - - @classmethod - def load(cls, fileobj, hdr={}, lazy_load=False): - ''' Loads streamlines from a file-like object. 
-
-        Parameters
-        ----------
-        fileobj : string or file-like object
-           If string, a filename; otherwise an open file-like object
-           pointing to TRK file (and ready to read from the beginning
-           of the TRK header)
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object
+        pointing to TRK file (and ready to read from the beginning
+        of the TRK header)
 
-        hdr : dict (optional)
+    Note
+    ----
+    TrackVis (and so its TRK file format) considers the streamline
+    coordinate (0,0,0) to be in the corner of the voxel, whereas NiBabel's
+    internal streamline representation (voxel space) assumes (0,0,0) to be
+    in the center of the voxel.
 
-        lazy_load : boolean (optional)
-            Load streamlines in a lazy manner i.e. they will not be kept
-            in memory.
+    Thus, streamlines are shifted by half a voxel on load and are shifted
+    back on save.
+    '''
+    def __init__(self, fileobj):
+        self.fileobj = fileobj
 
-        Returns
-        -------
-        streamlines : Streamlines object
-            Returns an object containing streamlines' data and header
-            information. See 'nibabel.Streamlines'.
-        '''
-        with Opener(fileobj) as f:
+        with Opener(self.fileobj) as f:
             # Read header
-            hdr_str = f.read(header_2_dtype.itemsize)
-            hdr_rec = np.fromstring(string=hdr_str, dtype=header_2_dtype)
+            header_str = f.read(header_2_dtype.itemsize)
+            header_rec = np.fromstring(string=header_str, dtype=header_2_dtype)
 
-            if hdr_rec['version'] == 1:
-                hdr_rec = np.fromstring(string=hdr_str, dtype=header_1_dtype)
-            elif hdr_rec['version'] == 2:
+            if header_rec['version'] == 1:
+                header_rec = np.fromstring(string=header_str, dtype=header_1_dtype)
+            elif header_rec['version'] == 2:
                 pass  # Nothing more to do
             else:
                 raise HeaderError('NiBabel only supports versions 1 and 2.')
 
-            # Convert the first record of `hdr_rec` into a dictionnary
-            hdr.update(dict(zip(hdr_rec.dtype.names, hdr_rec[0])))
+            # Convert the first record of `header_rec` into a dictionary
+            self.header = dict(zip(header_rec.dtype.names, header_rec[0]))
 
             # Check endianness
-            hdr[Field.ENDIAN] = native_code
-            if hdr['hdr_size'] != cls.HEADER_SIZE:
-                hdr[Field.ENDIAN] = swapped_code
-                hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0].newbyteorder()))  # Swap byte order
-                if hdr['hdr_size'] != cls.HEADER_SIZE:
-                    raise HeaderError('Invalid hdr_size: {0} instead of {1}'.format(hdr['hdr_size'], cls.HEADER_SIZE))
+            self.endianness = native_code
+            if self.header['hdr_size'] != TrkFile.HEADER_SIZE:
+                self.endianness = swapped_code
+
+                # Swap byte order
+                self.header = dict(zip(header_rec.dtype.names, header_rec[0].newbyteorder()))
+                if self.header['hdr_size'] != TrkFile.HEADER_SIZE:
+                    raise HeaderError('Invalid hdr_size: {0} instead of {1}'.format(self.header['hdr_size'], TrkFile.HEADER_SIZE))
 
             # By default, the voxel order is LPS.
             # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates
-            if hdr[Field.VOXEL_ORDER] == "":
-                warnings.warn("Voxel order is not specified, will assume 'LPS' since it is Trackvis software's default.")
-                hdr[Field.VOXEL_ORDER] = "LPS"
+            if self.header[Field.VOXEL_ORDER] == b"":
+                warnings.warn(("Voxel order is not specified, will assume"
+                               " 'LPS' since it is Trackvis software's"
+                               " default."), HeaderWarning)
+                self.header[Field.VOXEL_ORDER] = b"LPS"
 
             # Keep the file position where the data begin.
-            hdr['pos_data'] = f.tell()
-
-            # If 'count' field is 0, i.e. not provided, we have to loop until the EOF.
- if hdr[Field.NB_STREAMLINES] == 0: - del hdr[Field.NB_STREAMLINES] + self.offset_data = f.tell() - points = lambda: (x[0] for x in TrkFile._read_data(hdr, fileobj)) - scalars = lambda: (x[1] for x in TrkFile._read_data(hdr, fileobj)) - properties = lambda: (x[2] for x in TrkFile._read_data(hdr, fileobj)) - data = lambda: TrkFile._read_data(hdr, fileobj) + def __iter__(self): + i4_dtype = np.dtype(self.endianness + "i4") + f4_dtype = np.dtype(self.endianness + "f4") - if lazy_load: - count = lambda: TrkFile._count(hdr, fileobj) - if Field.NB_STREAMLINES in hdr: - count = hdr[Field.NB_STREAMLINES] - - streamlines = LazyStreamlines(points, scalars, properties, data=data, count=count) - else: - streamlines = Streamlines(*zip(*data())) - - # Set available header's information - streamlines.header.update(hdr) - return streamlines - - @staticmethod - def _read_data(hdr, fileobj): - ''' Read streamlines' data from a file-like object using a TRK's header. ''' - i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4") - f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4") - - with Opener(fileobj) as f: + with Opener(self.fileobj) as f: start_position = f.tell() - nb_pts_and_scalars = 3 + int(hdr[Field.NB_SCALARS_PER_POINT]) - pts_and_scalars_size = nb_pts_and_scalars * f4_dtype.itemsize + nb_pts_and_scalars = int(3 + self.header[Field.NB_SCALARS_PER_POINT]) + pts_and_scalars_size = int(nb_pts_and_scalars * f4_dtype.itemsize) slice_pts_and_scalars = lambda data: (data, []) - if hdr[Field.NB_SCALARS_PER_POINT] > 0: - # This is faster than np.split + if self.header[Field.NB_SCALARS_PER_POINT] > 0: + # This is faster than `np.split` slice_pts_and_scalars = lambda data: (data[:, :3], data[:, 3:]) # Using np.fromfile would be faster, but does not support StringIO @@ -211,18 +159,21 @@ def _read_data(hdr, fileobj): dtype=f4_dtype, buffer=f.read(nb_pts * pts_and_scalars_size))) - properties_size = int(hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) * f4_dtype.itemsize + properties_size = int(self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize) read_properties = lambda: [] - if hdr[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: + if self.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: read_properties = lambda: np.fromstring(f.read(properties_size), dtype=f4_dtype, - count=hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) + count=self.header[Field.NB_PROPERTIES_PER_STREAMLINE]) # Set the file position at the beginning of the data. - f.seek(hdr['pos_data'], os.SEEK_SET) + f.seek(self.offset_data, os.SEEK_SET) + + # If 'count' field is 0, i.e. not provided, we have to loop until the EOF. + nb_streamlines = self.header[Field.NB_STREAMLINES] + if nb_streamlines == 0: + nb_streamlines = np.inf - #for i in xrange(hdr[Field.NB_STREAMLINES]): - nb_streamlines = hdr.get(Field.NB_STREAMLINES, np.inf) i = 0 while i < nb_streamlines: nb_pts_str = f.read(i4_dtype.itemsize) @@ -239,163 +190,263 @@ def _read_data(hdr, fileobj): properties = read_properties() # TRK's streamlines are in 'voxelmm' space, we send them to voxel space. - pts = pts / hdr[Field.VOXEL_SIZES] + pts = pts / self.header[Field.VOXEL_SIZES] + # TrackVis considers coordinate (0,0,0) to be the corner of the + # voxel whereas streamlines returned assume (0,0,0) to be the + # center of the voxel. Thus, streamlines are shifted of half + #a voxel. + pts -= np.array(self.header[Field.VOXEL_SIZES])/2. yield pts, scalars, properties i += 1 # In case the 'count' field was not provided. 
-            hdr[Field.NB_STREAMLINES] = i
+        self.header[Field.NB_STREAMLINES] = i
 
-            # Set the file position where it was.
+            # Set the file position where it was (in case it was already open).
             f.seek(start_position, os.SEEK_CUR)
 
-    @staticmethod
-    def _count(hdr, fileobj):
-        ''' Count streamlines from a file-like object using a TRK's header. '''
-        nb_streamlines = 0
-        with Opener(fileobj) as f:
-            start_position = f.tell()
+class TrkWriter(object):
+    @classmethod
+    def create_empty_header(cls):
+        ''' Return an empty compliant TRK header. '''
+        header = np.zeros(1, dtype=header_2_dtype)
 
-            i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4")
-            f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4")
+        # Default values
+        header[Field.MAGIC_NUMBER] = TrkFile.MAGIC_NUMBER
+        header[Field.VOXEL_SIZES] = (1, 1, 1)
+        header[Field.DIMENSIONS] = (1, 1, 1)
+        header[Field.VOXEL_TO_WORLD] = np.eye(4)
+        header['version'] = 2
+        header['hdr_size'] = TrkFile.HEADER_SIZE
 
-            pts_and_scalars_size = (3 + hdr[Field.NB_SCALARS_PER_POINT]) * f4_dtype.itemsize
-            properties_size = hdr[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize
+        return header
 
-            # Set the file position at the beginning of the data.
-            f.seek(hdr['pos_data'], os.SEEK_SET)
+    def __init__(self, fileobj, header):
+        self.header = self.create_empty_header()
 
-            # Count the actual number of streamlines.
-            while True:
-                # Read number of points of the streamline
-                buf = f.read(i4_dtype.itemsize)
+        # Override the default fields with those contained in `header`.
+        for k, v in header.extra.items():
+            if k in header_2_dtype.fields.keys():
+                self.header[k] = v
 
-                if len(buf) == 0:
-                    break  # EOF
+        self.header[Field.NB_STREAMLINES] = 0
+        if header.nb_streamlines is not None:
+            self.header[Field.NB_STREAMLINES] = header.nb_streamlines
 
-                nb_pts = struct.unpack(i4_dtype.str[:-1], buf)[0]
-                bytes_to_skip = nb_pts * pts_and_scalars_size
-                bytes_to_skip += properties_size
+        self.header[Field.NB_SCALARS_PER_POINT] = header.nb_scalars_per_point
+        self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = header.nb_properties_per_streamline
+        self.header[Field.VOXEL_SIZES] = header.voxel_sizes
+        self.header[Field.VOXEL_TO_WORLD] = header.voxel_to_world
+        self.header[Field.VOXEL_ORDER] = header.voxel_order
 
-                # Seek to the next streamline in the file.
-                f.seek(bytes_to_skip, os.SEEK_CUR)
+        # Keep counts so incoherent header fields can be corrected (or warned about).
+        self.nb_streamlines = 0
+        self.nb_points = 0
+        self.nb_scalars = 0
+        self.nb_properties = 0
+
+        # Write header
+        self.file = Opener(fileobj, mode="wb")
+        # Keep track of the beginning of the header.
+        self.beginning = self.file.tell()
+        self.file.write(self.header[0].tostring())
+
+    def write(self, streamlines):
+        i4_dtype = np.dtype("i4")
+        f4_dtype = np.dtype("f4")
+
+        for points, scalars, properties in streamlines:
+            if len(scalars) > 0 and len(scalars) != len(points):
+                raise DataError("Missing scalars for some points!")
+
+            points = np.array(points, dtype=f4_dtype)
+            scalars = np.array(scalars, dtype=f4_dtype).reshape((len(points), -1))
+            properties = np.array(properties, dtype=f4_dtype)
+
+            # TRK's streamlines need to be in 'voxelmm' space.
+            points = points * self.header[Field.VOXEL_SIZES]
+            # TrackVis considers coordinate (0,0,0) to be the corner of the
+            # voxel whereas streamlines passed as parameters assume (0,0,0)
+            # to be the center of the voxel. Thus, streamlines are shifted
+            # by half a voxel.
+            points += np.array(self.header[Field.VOXEL_SIZES])/2.
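The two comment blocks just above pin down the unit conventions on the save path; they have to be the exact inverse of what `TrkReader` applies on load for a save/load cycle to preserve coordinates. A small standalone check of the intended round trip (plain NumPy; the ordering here applies the half-voxel shift in millimetres on both sides, which is one self-consistent choice):

    import numpy as np

    voxel_sizes = np.array([2., 2., 2.], dtype=np.float32)
    pts_voxel = (np.random.rand(10, 3) * 10).astype(np.float32)

    # Save path: voxel space -> 'voxelmm', then shift to TrackVis'
    # corner-of-voxel convention.
    pts_mm = pts_voxel * voxel_sizes + voxel_sizes / 2.

    # Load path: undo the shift first, then go back to voxel units.
    pts_back = (pts_mm - voxel_sizes / 2.) / voxel_sizes

    assert np.allclose(pts_voxel, pts_back)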
+
+            data = struct.pack(i4_dtype.str[:-1], len(points))
+            data += np.concatenate((points, scalars), axis=1).tostring()
+            data += properties.tostring()
+            self.file.write(data)
+
+            self.nb_streamlines += 1
+            self.nb_points += len(points)
+            self.nb_scalars += scalars.size
+            self.nb_properties += len(properties)
+
+        # Either correct or warn if header and data are incoherent.
+        # TODO: add a warn option as a function parameter.
+        nb_scalars_per_point = self.nb_scalars / self.nb_points
+        nb_properties_per_streamline = self.nb_properties / self.nb_streamlines
+
+        # Check for errors
+        if nb_scalars_per_point != int(nb_scalars_per_point):
+            raise DataError("Nb. of scalars differs from one point to another!")
+
+        if nb_properties_per_streamline != int(nb_properties_per_streamline):
+            raise DataError("Nb. of properties differs from one streamline to another!")
+
+        self.header[Field.NB_STREAMLINES] = self.nb_streamlines
+        self.header[Field.NB_SCALARS_PER_POINT] = nb_scalars_per_point
+        self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = nb_properties_per_streamline
+
+        # Overwrite header with updated one.
+        self.file.seek(self.beginning, os.SEEK_SET)
+        self.file.write(self.header[0].tostring())
 
-            nb_streamlines += 1
-            f.seek(start_position, os.SEEK_CUR)  # Set the file position where it was.
 
+class TrkFile(StreamlinesFile):
+    ''' Convenience class to encapsulate TRK file format.
 
-        return nb_streamlines
+    Note
+    ----
+    TrackVis (and so its TRK file format) considers the streamline
+    coordinate (0,0,0) to be in the corner of the voxel, whereas NiBabel's
+    internal streamline representation (voxel space) assumes (0,0,0) to be
+    in the center of the voxel.
+
+    Thus, streamlines are shifted by half a voxel on load and are shifted
+    back on save.
+    '''
+
+    # Constants
+    MAGIC_NUMBER = b"TRACK"
+    HEADER_SIZE = 1000
 
     @classmethod
-    def create_empty_header(cls):
-        ''' Return an empty TRK compliant header. '''
-        hdr = np.zeros(1, dtype=header_2_dtype)
+    def get_magic_number(cls):
+        ''' Return TRK's magic number. '''
+        return cls.MAGIC_NUMBER
 
-        #Default values
-        hdr[Field.MAGIC_NUMBER] = cls.MAGIC_NUMBER
-        hdr[Field.VOXEL_SIZES] = (1, 1, 1)
-        hdr[Field.DIMENSIONS] = (1, 1, 1)
-        hdr[Field.VOXEL_TO_WORLD] = np.eye(4)
-        hdr['version'] = 2
-        hdr['hdr_size'] = cls.HEADER_SIZE
+    @classmethod
+    def can_save_scalars(cls):
+        ''' Tells if the streamlines format supports saving scalars. '''
+        return True
 
-        return hdr
+    @classmethod
+    def can_save_properties(cls):
+        ''' Tells if the streamlines format supports saving properties. '''
+        return True
 
     @classmethod
-    def save(cls, streamlines, fileobj):
-        ''' Saves streamlines to a file-like object.
+    def is_correct_format(cls, fileobj):
+        ''' Check if the file is in TRK format.
 
         Parameters
         ----------
-        streamlines : Streamlines object
-            Object containing streamlines' data and header information.
-            See 'nibabel.Streamlines'.
-
         fileobj : string or file-like object
             If string, a filename; otherwise an open file-like object
             pointing to TRK file (and ready to read from the beginning
-            of the TRK header data)
-
-        hdr : dict (optional)
+            of the TRK header data).
 
-        Notes
-        -----
-        Streamlines are assumed to be in voxel space.
+        Returns
+        -------
+        is_correct_format : boolean
+            Returns True if `fileobj` is in TRK format.
''' - hdr = cls.create_empty_header() - - #Override hdr's fields by those contain in `streamlines`'s header - for k, v in streamlines.header.items(): - if k in header_2_dtype.fields.keys(): - hdr[k] = v + with Opener(fileobj) as f: + magic_number = f.read(5) + f.seek(-5, os.SEEK_CUR) + return magic_number == cls.MAGIC_NUMBER - # Check which endianess to use to write data. - endianess = streamlines.header.get(Field.ENDIAN, native_code) + return False - if endianess == swapped_code: - hdr = hdr.newbyteorder() + @staticmethod + def load(fileobj, ref=None, lazy_load=False): + ''' Loads streamlines from a file-like object. - i4_dtype = np.dtype(endianess + "i4") - f4_dtype = np.dtype(endianess + "f4") + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to TRK file (and ready to read from the beginning + of the TRK header). - # Keep counts for correcting incoherent fields or warn. - nb_streamlines = 0 - nb_points = 0 - nb_scalars = 0 - nb_properties = 0 + ref : filename | `Nifti1Image` object | 2D array (4,4) | None + Reference space where streamlines live in `fileobj`. - # Write header + data of streamlines - with Opener(fileobj, mode="wb") as f: - pos = f.tell() - # Write header - f.write(hdr[0].tostring()) + lazy_load : boolean (optional) + Load streamlines in a lazy manner i.e. they will not be kept + in memory. - for points, scalars, properties in streamlines: - if len(scalars) > 0 and len(scalars) != len(points): - raise DataError("Missing scalars for some points!") + Returns + ------- + streamlines : Streamlines object + Returns an object containing streamlines' data and header + information. See `nibabel.Streamlines`. - points = np.array(points, dtype=f4_dtype) - scalars = np.array(scalars, dtype=f4_dtype).reshape((len(points), -1)) - properties = np.array(properties, dtype=f4_dtype) + Notes + ----- + Streamlines are assumed to be in voxel space where coordinate (0,0,0) + refers to the center of the voxel. + ''' + trk_reader = TrkReader(fileobj) + + # Check if reference space matches one from TRK's header. + affine = trk_reader.header[Field.VOXEL_TO_WORLD] + if ref is not None: + affine = get_affine_from_reference(ref) + if not np.allclose(affine, trk_reader.header[Field.VOXEL_TO_WORLD]): + raise ValueError("Reference space provided does not match the " + " one from the TRK file header. 
Use `ref=None`" + " to use one contained in the TRK file") + + #points = lambda: (x[0] for x in trk_reader) + #scalars = lambda: (x[1] for x in trk_reader) + #properties = lambda: (x[2] for x in trk_reader) + data = lambda: iter(trk_reader) - # TRK's streamlines need to be in 'voxelmm' space - points = points * hdr[Field.VOXEL_SIZES] + if lazy_load: + streamlines = LazyStreamlines.create_from_data(data) - data = struct.pack(i4_dtype.str[:-1], len(points)) - data += np.concatenate((points, scalars), axis=1).tostring() - data += properties.tostring() - f.write(data) + # Overwrite scalars and properties if there is none + if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: + streamlines.scalars = lambda: [] + if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: + streamlines.properties = lambda: [] + else: + streamlines = Streamlines(*zip(*data())) - nb_streamlines += 1 - nb_points += len(points) - nb_scalars += scalars.size - nb_properties += len(properties) + # Overwrite scalars and properties if there is none + if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: + streamlines.scalars = [] + if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: + streamlines.properties = [] - # Either correct or warn if header and data are incoherent. - #TODO: add a warn option as a function parameter - nb_scalars_per_point = nb_scalars / nb_points - nb_properties_per_streamline = nb_properties / nb_streamlines + # Set available common information about streamlines in the header + streamlines.header.voxel_to_world = affine - # Check for errors - if nb_scalars_per_point != int(nb_scalars_per_point): - raise DataError("Nb. of scalars differs from one point to another!") + # If 'count' field is 0, i.e. not provided, we don't set `nb_streamlines` + if trk_reader.header[Field.NB_STREAMLINES] > 0: + streamlines.header.nb_streamlines = trk_reader.header[Field.NB_STREAMLINES] - if nb_properties_per_streamline != int(nb_properties_per_streamline): - raise DataError("Nb. of properties differs from one streamline to another!") + # Keep extra information about TRK format + streamlines.header.extra = trk_reader.header - hdr[Field.NB_STREAMLINES] = nb_streamlines - hdr[Field.NB_SCALARS_PER_POINT] = nb_scalars_per_point - hdr[Field.NB_PROPERTIES_PER_STREAMLINE] = nb_properties_per_streamline + ## Perform some integrity checks + #if trk_reader.header[Field.VOXEL_ORDER] != streamlines.header.voxel_order: + # raise HeaderError("'voxel_order' does not match the affine.") + #if streamlines.header.voxel_sizes != trk_reader.header[Field.VOXEL_SIZES]: + # raise HeaderError("'voxel_sizes' does not match the affine.") + #if streamlines.header.nb_scalars_per_point != trk_reader.header[Field.NB_SCALARS_PER_POINT]: + # raise HeaderError("'nb_scalars_per_point' does not match.") + #if streamlines.header.nb_properties_per_streamline != trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: + # raise HeaderError("'nb_properties_per_streamline' does not match.") - f.seek(pos, os.SEEK_SET) - f.write(hdr[0].tostring()) # Overwrite header with updated one. + return streamlines @staticmethod - def pretty_print(streamlines): - ''' Gets a formatted string contaning header's information - relevant to the TRK format. + def save(streamlines, fileobj, ref=None): + ''' Saves streamlines to a file-like object. Parameters ---------- @@ -403,12 +454,43 @@ def pretty_print(streamlines): Object containing streamlines' data and header information. See 'nibabel.Streamlines'. 
+ fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to TRK file (and ready to read from the beginning + of the TRK header data). + + ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) + Reference space where streamlines will live in `fileobj`. + + Notes + ----- + Streamlines are assumed to be in voxel space where coordinate (0,0,0) + refers to the center of the voxel. + ''' + if ref is not None: + streamlines.header.voxel_to_world = get_affine_from_reference(ref) + + trk_writer = TrkWriter(fileobj, streamlines.header) + trk_writer.write(streamlines) + + @staticmethod + def pretty_print(fileobj): + ''' Gets a formatted string of the header of a TRK file. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to TRK file (and ready to read from the beginning + of the header). + Returns ------- info : string - Header's information relevant to the TRK format. + Header information relevant to the TRK format. ''' - hdr = streamlines.header + trk_reader = TrkReader(fileobj) + hdr = trk_reader.header info = "" info += "MAGIC NUMBER: {0}".format(hdr[Field.MAGIC_NUMBER]) @@ -421,7 +503,6 @@ def pretty_print(streamlines): info += "nb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) info += "property_name:\n {0}".format("\n".join(hdr['property_name'])) info += "vox_to_world: {0}".format(hdr[Field.VOXEL_TO_WORLD]) - #info += "world_order: {0}".format(hdr[Field.WORLD_ORDER]) info += "voxel_order: {0}".format(hdr[Field.VOXEL_ORDER]) info += "image_orientation_patient: {0}".format(hdr['image_orientation_patient']) info += "pad1: {0}".format(hdr['pad1']) diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index 0afc7bccda..3b5f3ecf56 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -1,151 +1,38 @@ -import os +import numpy as np +import nibabel +import itertools -from ..externals.six import string_types +from nibabel.spatialimages import SpatialImage -from nibabel.streamlines import header -from nibabel.streamlines.base_format import LazyStreamlines -from nibabel.streamlines.trk import TrkFile -#from nibabel.streamlines.tck import TckFile -#from nibabel.streamlines.vtk import VtkFile +def get_affine_from_reference(ref): + """ Returns the affine defining the reference space. -# List of all supported formats -FORMATS = {".trk": TrkFile, - #".tck": TckFile, - #".vtk": VtkFile, - } - - -def is_supported(fileobj): - ''' Checks if the file-like object if supported by NiBabel. - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object pointing - to a streamlines file (and ready to read from the beginning of the - header) - - Returns - ------- - is_supported : boolean - ''' - return detect_format(fileobj) is not None - - -def detect_format(fileobj): - ''' Returns the StreamlinesFile object guessed from the file-like object. - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object pointing - to a streamlines file (and ready to read from the beginning of the - header) + Parameter + --------- + ref : filename | `Nifti1Image` object | 2D array (4,4) + Reference space where streamlines live in `fileobj`. Returns ------- - streamlines_file : StreamlinesFile object - Object that can be used to manage a streamlines file. - See 'nibabel.streamlines.StreamlinesFile'. 
- ''' - for format in FORMATS.values(): - try: - if format.is_correct_format(fileobj): - return format - - except IOError: - pass - - if isinstance(fileobj, string_types): - _, ext = os.path.splitext(fileobj) - return FORMATS.get(ext, None) - - return None - - -def load(fileobj, ref, lazy_load=False): - ''' Loads streamlines from a file-like object in voxel space. - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to a streamlines file (and ready to read from the beginning - of the streamlines file's header) - - ref : filename | `Nifti1Image` object | 2D array (3,3) | 2D array (4,4) | None - Reference space where streamlines have been created. - - lazy_load : boolean (optional) - Load streamlines in a lazy manner i.e. they will not be kept - in memory. - - Returns - ------- - obj : instance of ``Streamlines`` - Returns an instance of a ``Streamlines`` class containing data and metadata - of streamlines loaded from ``fileobj``. - ''' - streamlines_file = detect_format(fileobj) - - if streamlines_file is None: - raise TypeError("Unknown format for 'fileobj': {0}!".format(fileobj)) - - hdr = {} - if ref is not None: - hdr = header.create_header_from_reference(ref) - - return streamlines_file.load(fileobj, hdr=hdr, lazy_load=lazy_load) - - -def save(streamlines, filename, ref=None): - ''' Saves a ``Streamlines`` object to a file - - Parameters - ---------- - streamlines : Streamlines object - Streamlines to be saved (metadata is obtained with the function ``get_header`` of ``streamlines``). - - filename : string - Name of the file where the streamlines will be saved. The format will be guessed from ``filename``. - - ref : filename | `Nifti1Image` object | 2D array (3,3) | 2D array (4,4) | None (optional) - Reference space the streamlines belong to. Default: get ref from `streamlines.header`. - ''' - streamlines_file = detect_format(filename) - - if streamlines_file is None: - raise TypeError("Unknown format for 'filename': {0}!".format(filename)) - - if ref is not None: - # Create a `LazyStreamlines` from `streamlines` but using the new reference image. - streamlines = LazyStreamlines(data=iter(streamlines)) - streamlines.header.update(streamlines.header) - streamlines.header.update(header.create_header_from_reference(ref)) - - streamlines_file.save(streamlines, filename) - - -def convert(in_fileobj, out_filename): - ''' Converts one streamlines format to another. + affine : 2D array (4,4) + """ + if type(ref) is np.ndarray: + if ref.shape != (4, 4): + raise ValueError("`ref` needs to be a numpy array with shape (4,4)!") - It does not change the space in which the streamlines are. + return ref + elif isinstance(ref, SpatialImage): + return ref.affine - Parameters - ---------- - in_fileobj : string or file-like object - If string, a filename; otherwise an open file-like object pointing - to a streamlines file (and ready to read from the beginning of the - header) + # Assume `ref` is the name of a neuroimaging file. + return nibabel.load(ref).affine - out_filename : string - Name of the file where the streamlines will be saved. The format will - be guessed from ``out_filename``. 
- ''' - streamlines = load(in_fileobj, lazy_load=True) - save(streamlines, out_filename) +def pop(iterable): + "Returns the next item from the iterable else None" + value = list(itertools.islice(iterable, 1)) + return value[0] if len(value) > 0 else None # TODO def change_space(streamline_file, new_point_space): diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py index 28dcf2ede2..6abd9ee20d 100644 --- a/nibabel/testing/__init__.py +++ b/nibabel/testing/__init__.py @@ -9,7 +9,7 @@ ''' Utilities for testing ''' from os.path import dirname, abspath, join as pjoin from warnings import catch_warnings -from nibabel.externals.six.moves import zip +from nibabel.externals.six.moves import zip, zip_longest from numpy.testing import assert_array_equal @@ -61,17 +61,16 @@ def assert_allclose_safely(a, b, match_nans=True): def assert_arrays_equal(arrays1, arrays2): - for arr1, arr2 in zip(arrays1, arrays2): + for arr1, arr2 in zip_longest(arrays1, arrays2, fillvalue=None): assert_array_equal(arr1, arr2) def assert_streamlines_equal(s1, s2): - assert_equal(s1.get_header(), s2.get_header()) - - for (points1, scalars1, properties1), (points2, scalars2, properties2) in zip(s1, s2): - assert_array_equal(points1, points2) - assert_array_equal(scalars1, scalars2) - assert_array_equal(properties1, properties2) + assert_equal(s1.header, s2.header) + assert_equal(len(s1), len(s2)) + assert_arrays_equal(s1.points, s2.points) + assert_arrays_equal(s1.scalars, s2.scalars) + assert_arrays_equal(s1.properties, s2.properties) class suppress_warnings(catch_warnings): From 3f43b056a7fc0e047c9e8d4673b12091dd3cf5be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 17 Apr 2015 00:53:54 -0400 Subject: [PATCH 07/54] Fixed import of OrderedDict and added support to apply transformation on streamlines. --- nibabel/streamlines/base_format.py | 85 +++++++++++++++++-- nibabel/streamlines/header.py | 40 +++++---- nibabel/streamlines/tests/test_base_format.py | 18 ++-- nibabel/streamlines/tests/test_header.py | 12 +-- nibabel/streamlines/tests/test_streamlines.py | 16 ++-- nibabel/streamlines/trk.py | 16 ++-- nibabel/streamlines/utils.py | 4 - 7 files changed, 134 insertions(+), 57 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index a94f5728ad..e5068c69f8 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -1,3 +1,4 @@ +import numpy as np from warnings import warn from nibabel.externals.six.moves import zip_longest @@ -111,14 +112,56 @@ def __getitem__(self, idx): def __len__(self): return len(self.points) - def to_world_space(self, as_generator=False): - affine = self.header.voxel_to_world - new_points = (apply_affine(affine, pts) for pts in self.points) + def copy(self): + """ Returns a copy of this `Streamlines` object. """ + streamlines = Streamlines(self.points, self.scalars, self.properties) + streamlines._header = self.header.copy() + return streamlines - if not as_generator: - return list(new_points) + def transform(self, affine, lazy=False): + """ Applies an affine transformation on the points of each streamline. - return new_points + Parameters + ---------- + affine : 2D array (4,4) + Transformation that will be applied on each streamline. + lazy : bool (optional) + If true output will be a generator of arrays instead of a list. + + Returns + ------- + streamlines + If `lazy` is true, a `LazyStreamlines` object is returned, + otherwise a `Streamlines` object is returned. 
In both case, + streamlines are in a space defined by `affine`. + """ + points = lambda: (apply_affine(affine, pts) for pts in self.points) + if not lazy: + points = list(points()) + + streamlines = self.copy() + streamlines.points = points + streamlines.header.to_world_space = np.dot(streamlines.header.to_world_space, + np.linalg.inv(affine)) + + return streamlines + + def to_world_space(self, lazy=False): + """ Sends the streamlines back into world space. + + Parameters + ---------- + lazy : bool (optional) + If true output will be a generator of arrays instead of a list. + + Returns + ------- + streamlines + If `lazy` is true, a `LazyStreamlines` object is returned, + otherwise a `Streamlines` object is returned. In both case, + streamlines are in world space. + """ + return self.transform(self.header.to_world_space, lazy) class LazyStreamlines(Streamlines): @@ -255,8 +298,36 @@ def __len__(self): return self.header.nb_streamlines + def copy(self): + """ Returns a copy of this `LazyStreamlines` object. """ + streamlines = LazyStreamlines(self._points, self._scalars, self._properties) + streamlines._header = self.header.copy() + return streamlines + + def transform(self, affine): + """ Applies an affine transformation on the points of each streamline. + + Parameters + ---------- + affine : 2D array (4,4) + Transformation that will be applied on each streamline. + + Returns + ------- + streamlines : `LazyStreamlines` object + Streamlines living in a space defined by `affine`. + """ + return super(LazyStreamlines, self).transform(affine, lazy=True) + def to_world_space(self): - return super(LazyStreamlines, self).to_world_space(as_generator=True) + """ Sends the streamlines back into world space. + + Returns + ------- + streamlines : `LazyStreamlines` object + Streamlines living in world space. + """ + return super(LazyStreamlines, self).to_world_space(lazy=True) class StreamlinesFile: diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py index 46e4af3a94..7a7ec63b3c 100644 --- a/nibabel/streamlines/header.py +++ b/nibabel/streamlines/header.py @@ -1,6 +1,7 @@ +import copy import numpy as np from nibabel.orientations import aff2axcodes -from collections import OrderedDict +from nibabel.externals import OrderedDict class Field: @@ -18,7 +19,7 @@ class Field: DIMENSIONS = "dimensions" MAGIC_NUMBER = "magic_number" ORIGIN = "origin" - VOXEL_TO_WORLD = "voxel_to_world" + to_world_space = "to_world_space" VOXEL_ORDER = "voxel_order" ENDIAN = "endian" @@ -28,33 +29,33 @@ def __init__(self): self._nb_streamlines = None self._nb_scalars_per_point = None self._nb_properties_per_streamline = None - self._voxel_to_world = np.eye(4) + self._to_world_space = np.eye(4) self.extra = OrderedDict() @property - def voxel_to_world(self): - return self._voxel_to_world + def to_world_space(self): + return self._to_world_space - @voxel_to_world.setter - def voxel_to_world(self, value): - self._voxel_to_world = np.array(value, dtype=np.float32) + @to_world_space.setter + def to_world_space(self, value): + self._to_world_space = np.array(value, dtype=np.float32) @property def voxel_sizes(self): - """ Get voxel sizes from voxel_to_world. """ - return np.sqrt(np.sum(self.voxel_to_world[:3, :3]**2, axis=0)) + """ Get voxel sizes from to_world_space. 
""" + return np.sqrt(np.sum(self.to_world_space[:3, :3]**2, axis=0)) @voxel_sizes.setter def voxel_sizes(self, value): scaling = np.r_[np.array(value), [1]] old_scaling = np.r_[np.array(self.voxel_sizes), [1]] # Remove old scaling and apply new one - self.voxel_to_world = np.dot(np.diag(scaling/old_scaling), self.voxel_to_world) + self.to_world_space = np.dot(np.diag(scaling/old_scaling), self.to_world_space) @property def voxel_order(self): - """ Get voxel order from voxel_to_world. """ - return "".join(aff2axcodes(self.voxel_to_world)) + """ Get voxel order from to_world_space. """ + return "".join(aff2axcodes(self.to_world_space)) @property def nb_streamlines(self): @@ -88,8 +89,17 @@ def extra(self): def extra(self, value): self._extra = OrderedDict(value) + def copy(self): + header = StreamlinesHeader() + header._nb_streamlines = self.nb_streamlines + header.nb_scalars_per_point = self.nb_scalars_per_point + header.nb_properties_per_streamline = self.nb_properties_per_streamline + header.to_world_space = self.to_world_space.copy() + header.extra = copy.deepcopy(self.extra) + return header + def __eq__(self, other): - return (np.allclose(self.voxel_to_world, other.voxel_to_world) and + return (np.allclose(self.to_world_space, other.to_world_space) and self.nb_streamlines == other.nb_streamlines and self.nb_scalars_per_point == other.nb_scalars_per_point and self.nb_properties_per_streamline == other.nb_properties_per_streamline and @@ -100,7 +110,7 @@ def __repr__(self): txt += "nb_streamlines: " + repr(self.nb_streamlines) + '\n' txt += "nb_scalars_per_point: " + repr(self.nb_scalars_per_point) + '\n' txt += "nb_properties_per_streamline: " + repr(self.nb_properties_per_streamline) + '\n' - txt += "voxel_to_world: " + repr(self.voxel_to_world) + '\n' + txt += "to_world_space: " + repr(self.to_world_space) + '\n' txt += "voxel_sizes: " + repr(self.voxel_sizes) + '\n' txt += "Extra fields: {\n" diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index e132ccf576..f2396da411 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -117,8 +117,8 @@ def test_to_world_space(self): # World space is (RAS+) with voxel size of 2x3x4mm. 
streamlines.header.voxel_sizes = (2, 3, 4) - new_points = streamlines.to_world_space() - for new_pts, pts in zip(new_points, self.points): + new_streamlines = streamlines.to_world_space() + for new_pts, pts in zip(new_streamlines.points, self.points): for dim, size in enumerate(streamlines.header.voxel_sizes): assert_array_almost_equal(new_pts[:, dim], size*pts[:, dim]) @@ -129,7 +129,7 @@ def test_header(self): assert_equal(streamlines.header.nb_scalars_per_point, 0) assert_equal(streamlines.header.nb_properties_per_streamline, 0) assert_array_equal(streamlines.header.voxel_sizes, (1, 1, 1)) - assert_array_equal(streamlines.header.voxel_to_world, np.eye(4)) + assert_array_equal(streamlines.header.to_world_space, np.eye(4)) assert_equal(streamlines.header.extra, {}) streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) @@ -138,15 +138,15 @@ def test_header(self): assert_equal(streamlines.header.nb_scalars_per_point, self.colors[0].shape[1]) assert_equal(streamlines.header.nb_properties_per_streamline, self.mean_curvature_torsion[0].shape[0]) - # Modifying voxel_sizes should be reflected in voxel_to_world + # Modifying voxel_sizes should be reflected in to_world_space streamlines.header.voxel_sizes = (2, 3, 4) assert_array_equal(streamlines.header.voxel_sizes, (2, 3, 4)) - assert_array_equal(np.diag(streamlines.header.voxel_to_world), (2, 3, 4, 1)) + assert_array_equal(np.diag(streamlines.header.to_world_space), (2, 3, 4, 1)) - # Modifying scaling of voxel_to_world should be reflected in voxel_sizes - streamlines.header.voxel_to_world = np.diag([4, 3, 2, 1]) + # Modifying scaling of to_world_space should be reflected in voxel_sizes + streamlines.header.to_world_space = np.diag([4, 3, 2, 1]) assert_array_equal(streamlines.header.voxel_sizes, (4, 3, 2)) - assert_array_equal(streamlines.header.voxel_to_world, np.diag([4, 3, 2, 1])) + assert_array_equal(streamlines.header.to_world_space, np.diag([4, 3, 2, 1])) # Test that we can run __repr__ without error. 
repr(streamlines.header) @@ -313,7 +313,7 @@ def test_lazy_streamlines_header(self): assert_equal(streamlines.header.nb_scalars_per_point, 0) assert_equal(streamlines.header.nb_properties_per_streamline, 0) assert_array_equal(streamlines.header.voxel_sizes, (1, 1, 1)) - assert_array_equal(streamlines.header.voxel_to_world, np.eye(4)) + assert_array_equal(streamlines.header.to_world_space, np.eye(4)) assert_equal(streamlines.header.extra, {}) points = lambda: (x for x in self.points) diff --git a/nibabel/streamlines/tests/test_header.py b/nibabel/streamlines/tests/test_header.py index 99184bb2ba..a36a257818 100644 --- a/nibabel/streamlines/tests/test_header.py +++ b/nibabel/streamlines/tests/test_header.py @@ -12,7 +12,7 @@ def test_streamlines_header(): assert_true(header.nb_scalars_per_point is None) assert_true(header.nb_properties_per_streamline is None) assert_array_equal(header.voxel_sizes, (1, 1, 1)) - assert_array_equal(header.voxel_to_world, np.eye(4)) + assert_array_equal(header.to_world_space, np.eye(4)) assert_equal(header.extra, {}) # Modify simple attributes @@ -23,15 +23,15 @@ def test_streamlines_header(): assert_equal(header.nb_scalars_per_point, 2) assert_equal(header.nb_properties_per_streamline, 3) - # Modifying voxel_sizes should be reflected in voxel_to_world + # Modifying voxel_sizes should be reflected in to_world_space header.voxel_sizes = (2, 3, 4) assert_array_equal(header.voxel_sizes, (2, 3, 4)) - assert_array_equal(np.diag(header.voxel_to_world), (2, 3, 4, 1)) + assert_array_equal(np.diag(header.to_world_space), (2, 3, 4, 1)) - # Modifying scaling of voxel_to_world should be reflected in voxel_sizes - header.voxel_to_world = np.diag([4, 3, 2, 1]) + # Modifying scaling of to_world_space should be reflected in voxel_sizes + header.to_world_space = np.diag([4, 3, 2, 1]) assert_array_equal(header.voxel_sizes, (4, 3, 2)) - assert_array_equal(header.voxel_to_world, np.diag([4, 3, 2, 1])) + assert_array_equal(header.to_world_space, np.diag([4, 3, 2, 1])) # Test that we can run __repr__ without error. 
repr(header) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index f137559631..c293ece3b1 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -140,12 +140,12 @@ def setUp(self): self.nb_streamlines = len(self.points) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) - self.voxel_to_world = np.eye(4) + self.to_world_space = np.eye(4) def test_load_empty_file(self): for empty_filename in self.empty_filenames: streamlines = nib.streamlines.load(empty_filename, - ref=self.voxel_to_world, + ref=self.to_world_space, lazy_load=False) assert_true(type(streamlines), Streamlines) check_streamlines(streamlines, 0, [], [], []) @@ -153,7 +153,7 @@ def test_load_empty_file(self): def test_load_simple_file(self): for simple_filename in self.simple_filenames: streamlines = nib.streamlines.load(simple_filename, - ref=self.voxel_to_world, + ref=self.to_world_space, lazy_load=False) assert_true(type(streamlines), Streamlines) check_streamlines(streamlines, self.nb_streamlines, @@ -161,7 +161,7 @@ def test_load_simple_file(self): # Test lazy_load streamlines = nib.streamlines.load(simple_filename, - ref=self.voxel_to_world, + ref=self.to_world_space, lazy_load=True) assert_true(type(streamlines), LazyStreamlines) check_streamlines(streamlines, self.nb_streamlines, @@ -180,7 +180,7 @@ def test_load_complex_file(self): properties = self.mean_curvature_torsion streamlines = nib.streamlines.load(complex_filename, - ref=self.voxel_to_world, + ref=self.to_world_space, lazy_load=False) assert_true(type(streamlines), Streamlines) check_streamlines(streamlines, self.nb_streamlines, @@ -188,7 +188,7 @@ def test_load_complex_file(self): # Test lazy_load streamlines = nib.streamlines.load(complex_filename, - ref=self.voxel_to_world, + ref=self.to_world_space, lazy_load=True) assert_true(type(streamlines), LazyStreamlines) check_streamlines(streamlines, self.nb_streamlines, @@ -199,7 +199,7 @@ def test_save_simple_file(self): for ext in nib.streamlines.FORMATS.keys(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: nib.streamlines.save(streamlines, f.name) - loaded_streamlines = nib.streamlines.load(f, ref=self.voxel_to_world, lazy_load=False) + loaded_streamlines = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) check_streamlines(loaded_streamlines, self.nb_streamlines, self.points, [], []) @@ -224,6 +224,6 @@ def test_save_complex_file(self): if cls.can_save_properties(): properties = self.mean_curvature_torsion - loaded_streamlines = nib.streamlines.load(f, ref=self.voxel_to_world, lazy_load=False) + loaded_streamlines = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) check_streamlines(loaded_streamlines, self.nb_streamlines, self.points, scalars, properties) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 0c9c982d89..2a9ea9d29d 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -58,7 +58,7 @@ ('scalar_name', 'S20', 10), (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), ('property_name', 'S20', 10), - (Field.VOXEL_TO_WORLD, 'f4', (4, 4)), # new field for version 2 + (Field.to_world_space, 'f4', (4, 4)), # new field for version 2 ('reserved', 'S444'), (Field.VOXEL_ORDER, 'S4'), ('pad2', 'S4'), @@ -217,7 +217,7 @@ def create_empty_header(cls): header[Field.MAGIC_NUMBER] = TrkFile.MAGIC_NUMBER header[Field.VOXEL_SIZES] = (1, 1, 1) 
header[Field.DIMENSIONS] = (1, 1, 1) - header[Field.VOXEL_TO_WORLD] = np.eye(4) + header[Field.to_world_space] = np.eye(4) header['version'] = 2 header['hdr_size'] = TrkFile.HEADER_SIZE @@ -238,7 +238,7 @@ def __init__(self, fileobj, header): self.header[Field.NB_SCALARS_PER_POINT] = header.nb_scalars_per_point self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = header.nb_properties_per_streamline self.header[Field.VOXEL_SIZES] = header.voxel_sizes - self.header[Field.VOXEL_TO_WORLD] = header.voxel_to_world + self.header[Field.to_world_space] = header.to_world_space self.header[Field.VOXEL_ORDER] = header.voxel_order # Keep counts for correcting incoherent fields or warn. @@ -392,10 +392,10 @@ def load(fileobj, ref=None, lazy_load=False): trk_reader = TrkReader(fileobj) # Check if reference space matches one from TRK's header. - affine = trk_reader.header[Field.VOXEL_TO_WORLD] + affine = trk_reader.header[Field.to_world_space] if ref is not None: affine = get_affine_from_reference(ref) - if not np.allclose(affine, trk_reader.header[Field.VOXEL_TO_WORLD]): + if not np.allclose(affine, trk_reader.header[Field.to_world_space]): raise ValueError("Reference space provided does not match the " " one from the TRK file header. Use `ref=None`" " to use one contained in the TRK file") @@ -423,7 +423,7 @@ def load(fileobj, ref=None, lazy_load=False): streamlines.properties = [] # Set available common information about streamlines in the header - streamlines.header.voxel_to_world = affine + streamlines.header.to_world_space = affine # If 'count' field is 0, i.e. not provided, we don't set `nb_streamlines` if trk_reader.header[Field.NB_STREAMLINES] > 0: @@ -468,7 +468,7 @@ def save(streamlines, fileobj, ref=None): refers to the center of the voxel. ''' if ref is not None: - streamlines.header.voxel_to_world = get_affine_from_reference(ref) + streamlines.header.to_world_space = get_affine_from_reference(ref) trk_writer = TrkWriter(fileobj, streamlines.header) trk_writer.write(streamlines) @@ -502,7 +502,7 @@ def pretty_print(fileobj): info += "scalar_name:\n {0}".format("\n".join(hdr['scalar_name'])) info += "nb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) info += "property_name:\n {0}".format("\n".join(hdr['property_name'])) - info += "vox_to_world: {0}".format(hdr[Field.VOXEL_TO_WORLD]) + info += "vox_to_world: {0}".format(hdr[Field.to_world_space]) info += "voxel_order: {0}".format(hdr[Field.VOXEL_ORDER]) info += "image_orientation_patient: {0}".format(hdr['image_orientation_patient']) info += "pad1: {0}".format(hdr['pad1']) diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index 3b5f3ecf56..7bbbe1ef8d 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -33,7 +33,3 @@ def pop(iterable): "Returns the next item from the iterable else None" value = list(itertools.islice(iterable, 1)) return value[0] if len(value) > 0 else None - -# TODO -def change_space(streamline_file, new_point_space): - pass From 418eb266700ac2a2180a5e8930825048e8e8dd7e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 22 Sep 2015 19:44:21 -0400 Subject: [PATCH 08/54] Added a Streamline class used when indexing/iterating a Streamlines object --- nibabel/streamlines/base_format.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index e5068c69f8..2db3f4e4d6 100644 --- a/nibabel/streamlines/base_format.py +++ 
b/nibabel/streamlines/base_format.py @@ -24,6 +24,19 @@ class DataError(Exception): pass +class Streamline(object): + def __init__(self, points, scalars=None, properties=None): + self.points = points + self.scalars = scalars + self.properties = properties + + def __iter__(self): + return iter(self.points) + + def __len__(self): + return len(self.points) + + class Streamlines(object): ''' Class containing information about streamlines. @@ -92,7 +105,8 @@ def properties(self, value): self.header.nb_properties_per_streamline = len(self.properties[0]) def __iter__(self): - return zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) + for data in zip_longest(self.points, self.scalars, self.properties, fillvalue=[]): + yield Streamline(*data) def __getitem__(self, idx): pts = self.points[idx] @@ -107,7 +121,7 @@ def __getitem__(self, idx): if type(idx) is slice: return list(zip_longest(pts, scalars, properties, fillvalue=[])) - return pts, scalars, properties + return Streamline(pts, scalars, properties) def __len__(self): return len(self.points) From b66bcc537b9b88f538a301df731ab6791fcd5147 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 28 Oct 2015 12:44:04 -0400 Subject: [PATCH 09/54] Finished merging with master --- nibabel/testing/__init__.py | 46 +++++++++++++++++++++++++++++++++---- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py index dccccc001e..d28157479b 100644 --- a/nibabel/testing/__init__.py +++ b/nibabel/testing/__init__.py @@ -120,10 +120,12 @@ class clear_and_catch_warnings(warnings.catch_warnings): ... warnings.simplefilter('always') ... # do something that raises a warning in np.core.fromnumeric """ - def __init__(self, *args, **kwargs): - self.modules = kwargs.pop('modules', []) + class_modules = () + + def __init__(self, record=True, modules=()): + self.modules = set(modules).union(self.class_modules) self._warnreg_copies = {} - super(catch_warn_reset, self).__init__(*args, **kwargs) + super(clear_and_catch_warnings, self).__init__(record=record) def __enter__(self): for mod in self.modules: @@ -131,12 +133,46 @@ def __enter__(self): mod_reg = mod.__warningregistry__ self._warnreg_copies[mod] = mod_reg.copy() mod_reg.clear() - return super(catch_warn_reset, self).__enter__() + return super(clear_and_catch_warnings, self).__enter__() def __exit__(self, *exc_info): - super(catch_warn_reset, self).__exit__(*exc_info) + super(clear_and_catch_warnings, self).__exit__(*exc_info) for mod in self.modules: if hasattr(mod, '__warningregistry__'): mod.__warningregistry__.clear() if mod in self._warnreg_copies: mod.__warningregistry__.update(self._warnreg_copies[mod]) + + +class error_warnings(clear_and_catch_warnings): + """ Context manager to check for warnings as errors. Usually used with + ``assert_raises`` in the with block + + Examples + -------- + >>> with error_warnings(): + ... try: + ... warnings.warn('Message', UserWarning) + ... except UserWarning: + ... 
print('I consider myself warned') + I consider myself warned + """ + filter = 'error' + + def __enter__(self): + mgr = super(error_warnings, self).__enter__() + warnings.simplefilter(self.filter) + return mgr + + +class suppress_warnings(error_warnings): + """ Version of ``catch_warnings`` class that suppresses warnings + """ + filter = 'ignore' + + +class catch_warn_reset(clear_and_catch_warnings): + def __init__(self, *args, **kwargs): + warnings.warn('catch_warn_reset is deprecated and will be removed in ' + 'nibabel v3.0; use nibabel.testing.clear_and_catch_warnings.', + FutureWarning) From c6d93e8c094af66649cd217bdca1f98ea98f7538 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 28 Oct 2015 16:46:50 -0400 Subject: [PATCH 10/54] Fixed tests to use clear_and_catch_warnings --- nibabel/streamlines/base_format.py | 18 ++---------------- nibabel/streamlines/tests/test_base_format.py | 8 ++++---- nibabel/streamlines/tests/test_streamlines.py | 4 ++-- nibabel/streamlines/tests/test_trk.py | 6 +++--- nibabel/testing/__init__.py | 3 +++ 5 files changed, 14 insertions(+), 25 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 2db3f4e4d6..e5068c69f8 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -24,19 +24,6 @@ class DataError(Exception): pass -class Streamline(object): - def __init__(self, points, scalars=None, properties=None): - self.points = points - self.scalars = scalars - self.properties = properties - - def __iter__(self): - return iter(self.points) - - def __len__(self): - return len(self.points) - - class Streamlines(object): ''' Class containing information about streamlines. @@ -105,8 +92,7 @@ def properties(self, value): self.header.nb_properties_per_streamline = len(self.properties[0]) def __iter__(self): - for data in zip_longest(self.points, self.scalars, self.properties, fillvalue=[]): - yield Streamline(*data) + return zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) def __getitem__(self, idx): pts = self.points[idx] @@ -121,7 +107,7 @@ def __getitem__(self, idx): if type(idx) is slice: return list(zip_longest(pts, scalars, properties, fillvalue=[])) - return Streamline(pts, scalars, properties) + return pts, scalars, properties def __len__(self): return len(self.points) diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index f2396da411..c8a9465265 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -4,7 +4,7 @@ import warnings from nibabel.testing import assert_arrays_equal -from nibabel.testing import suppress_warnings, catch_warn_reset +from nibabel.testing import suppress_warnings, clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true from numpy.testing import assert_array_equal, assert_array_almost_equal from nibabel.externals.six.moves import zip @@ -267,7 +267,7 @@ def test_lazy_streamlines_len(self): scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - with catch_warn_reset(record=True, modules=[base_format]) as w: + with clear_and_catch_warnings(record=True, modules=[base_format]) as w: warnings.simplefilter("always") # Always trigger warnings. # Calling `len` will create new generators each time. 
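The test changes below lean on this machinery, so the pattern is worth a quick illustration: `clear_and_catch_warnings` empties the `__warningregistry__` of the given modules on entry, so a warning already triggered by an earlier test is raised (and recorded) again instead of being suppressed by the 'default' filter. A minimal sketch of the idiom used in the following hunks (the warning message is illustrative only):

    import warnings

    from nibabel.streamlines import base_format
    from nibabel.testing import clear_and_catch_warnings

    with clear_and_catch_warnings(record=True, modules=[base_format]) as w:
        warnings.simplefilter("always")  # record repeated warnings too
        warnings.warn("something looks off", UserWarning)

    assert len(w) == 1
    assert issubclass(w[0].category, UserWarning)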
@@ -286,7 +286,7 @@ def test_lazy_streamlines_len(self): assert_equal(len(streamlines), self.nb_streamlines) assert_equal(len(w), 2) - with catch_warn_reset(record=True, modules=[base_format]) as w: + with clear_and_catch_warnings(record=True, modules=[base_format]) as w: # Once we iterated through the streamlines, we know the length. streamlines = LazyStreamlines(points, scalars, properties) assert_true(streamlines.header.nb_streamlines is None) @@ -298,7 +298,7 @@ def test_lazy_streamlines_len(self): assert_equal(len(streamlines), len(self.points)) assert_equal(len(w), 0) - with catch_warn_reset(record=True, modules=[base_format]) as w: + with clear_and_catch_warnings(record=True, modules=[base_format]) as w: # It first checks if number of streamlines is in the header. streamlines = LazyStreamlines(points, scalars, properties) streamlines.header.nb_streamlines = 1234 diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index c293ece3b1..893331cb77 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -8,7 +8,7 @@ import nibabel as nib from nibabel.externals.six import BytesIO -from nibabel.testing import catch_warn_reset +from nibabel.testing import clear_and_catch_warnings from nibabel.testing import assert_arrays_equal from nose.tools import assert_equal, assert_raises, assert_true, assert_false @@ -207,7 +207,7 @@ def test_save_complex_file(self): streamlines = Streamlines(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) for ext, cls in nib.streamlines.FORMATS.items(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: - with catch_warn_reset(record=True, modules=[trk]) as w: + with clear_and_catch_warnings(record=True, modules=[trk]) as w: nib.streamlines.save(streamlines, f.name) # If streamlines format does not support saving scalars or diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 2d47f5c841..1a1aa2ce05 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -4,7 +4,7 @@ from nibabel.externals.six import BytesIO -from nibabel.testing import suppress_warnings, catch_warn_reset +from nibabel.testing import suppress_warnings, clear_and_catch_warnings from nibabel.testing import assert_arrays_equal, assert_streamlines_equal from nose.tools import assert_equal, assert_raises, assert_true @@ -103,7 +103,7 @@ def test_load_file_with_wrong_information(self): check_streamlines(streamlines, self.nb_streamlines, self.points, [], []) streamlines = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) - with catch_warn_reset(record=True, modules=[base_format]) as w: + with clear_and_catch_warnings(record=True, modules=[base_format]) as w: check_streamlines(streamlines, self.nb_streamlines, self.points, [], []) assert_equal(len(w), 1) assert_true(issubclass(w[0].category, UsageWarning)) @@ -111,7 +111,7 @@ def test_load_file_with_wrong_information(self): # Simulate a TRK file where `voxel_order` was not provided. 
voxel_order = np.zeros(1, dtype="|S3").tostring() new_trk_file = trk_file[:948] + voxel_order + trk_file[948+3:] - with catch_warn_reset(record=True, modules=[trk]) as w: + with clear_and_catch_warnings(record=True, modules=[trk]) as w: TrkFile.load(BytesIO(new_trk_file), ref=None) assert_equal(len(w), 1) assert_true(issubclass(w[0].category, HeaderWarning)) diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py index d28157479b..3de3809477 100644 --- a/nibabel/testing/__init__.py +++ b/nibabel/testing/__init__.py @@ -12,8 +12,10 @@ import sys import warnings from os.path import dirname, abspath, join as pjoin +from nibabel.externals.six.moves import zip_longest import numpy as np +from numpy.testing import assert_array_equal # Allow failed import of nose if not now running tests try: @@ -69,6 +71,7 @@ def assert_streamlines_equal(s1, s2): assert_arrays_equal(s1.scalars, s2.scalars) assert_arrays_equal(s1.properties, s2.properties) + def get_fresh_mod(mod_name=__name__): # Get this module, with warning registry empty my_mod = sys.modules[mod_name] From fcd962fe75317a36800eda482360eada71739593 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Thu, 29 Oct 2015 13:45:34 -0400 Subject: [PATCH 11/54] Added the CompactList data structure to keep points and scalars --- nibabel/benchmarks/bench_streamlines.py | 33 ++-- nibabel/streamlines/base_format.py | 165 +++++++++++++++++- nibabel/streamlines/tests/test_base_format.py | 37 ++-- nibabel/streamlines/tests/test_streamlines.py | 2 +- nibabel/streamlines/tests/test_trk.py | 52 +++--- nibabel/streamlines/trk.py | 62 ++++++- 6 files changed, 281 insertions(+), 70 deletions(-) diff --git a/nibabel/benchmarks/bench_streamlines.py b/nibabel/benchmarks/bench_streamlines.py index 3a2e3ab39d..4eba60be7d 100644 --- a/nibabel/benchmarks/bench_streamlines.py +++ b/nibabel/benchmarks/bench_streamlines.py @@ -40,13 +40,21 @@ def bench_load_trk(): repeat = 20 trk_file = BytesIO() - trk = list(zip(points, [None]*NB_STREAMLINES, [None]*NB_STREAMLINES)) - tv.write(trk_file, trk) + #trk = list(zip(points, [None]*NB_STREAMLINES, [None]*NB_STREAMLINES)) + #tv.write(trk_file, trk) + streamlines = Streamlines(points) + TrkFile.save(streamlines, trk_file) + + from pycallgraph import PyCallGraph + from pycallgraph.output import GraphvizOutput - mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, lazy_load=False)', repeat) - print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) + with PyCallGraph(output=GraphvizOutput()): + #nib.streamlines.load(trk_file, ref=None, lazy_load=False) - mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file)', repeat) + mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, ref=None, lazy_load=False)', repeat) + print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) + + mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file, points_space="voxel")', repeat) print("Old: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old)) print("Speedup of %2f" % (mtime_old/mtime_new)) @@ -54,13 +62,15 @@ def bench_load_trk(): scalars = [np.random.rand(NB_POINTS, 10).astype('float32') for i in range(NB_STREAMLINES)] trk_file = BytesIO() - trk = list(zip(points, scalars, [None]*NB_STREAMLINES)) - tv.write(trk_file, trk) + #trk = list(zip(points, scalars, [None]*NB_STREAMLINES)) + #tv.write(trk_file, trk) + streamlines = Streamlines(points, scalars) + TrkFile.save(streamlines, 
trk_file) - mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, lazy_load=False)', repeat) + mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, ref=None, lazy_load=False)', repeat) print("New: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_new)) - mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file)', repeat) + mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file, points_space="voxel")', repeat) print("Old: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_old)) print("Speedup of %2f" % (mtime_old/mtime_new)) @@ -91,8 +101,7 @@ def bench_save_trk(): for pts, A in zip(points, streams): assert_array_equal(pts, A[0]) - trk = nib.streamlines.load(trk_file_new, lazy_load=False) - + trk = nib.streamlines.load(trk_file_new, ref=None, lazy_load=False) assert_arrays_equal(points, trk.points) # Points and scalars @@ -117,7 +126,7 @@ def bench_save_trk(): assert_array_equal(pts, A[0]) assert_array_equal(scal, A[1]) - trk = nib.streamlines.load(trk_file_new, lazy_load=False) + trk = nib.streamlines.load(trk_file_new, ref=None, lazy_load=False) assert_arrays_equal(points, trk.points) assert_arrays_equal(scalars, trk.scalars) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index e5068c69f8..8c6f3a9973 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -24,6 +24,128 @@ class DataError(Exception): pass +class CompactList(object): + def __init__(self, preallocate=(0,), dtype=np.float32): + self.dtype = dtype + self.data = np.empty(preallocate, dtype=dtype) + self.offsets = [] + self.lengths = [] + + @classmethod + def from_list(cls, elements): + """ Fast way to create a `CompactList` object from a list of elements. + Parameters + ---------- + elements : list + List of 2D ndarrays of same shape except on the first dimension. + """ + if len(elements) == 0: + return cls() + + first_element = np.asarray(elements[0]) + s = cls(preallocate=(0,) + first_element.shape[1:], dtype=first_element.dtype) + s.extend(elements) + return s + + def append(self, element): + """ + Parameters + ---------- + element : 2D ndarray with `element.shape[1:] == self.shape` + Element to add. + Note + ---- + If you need to add multiple elements you should consider + `CompactList.from_list` or `CompactList.extend`. + """ + self.offsets.append(len(self.data)) + self.lengths.append(len(element)) + self.data = np.append(self.data, element, axis=0) + + def extend(self, elements): + if isinstance(elements, CompactList): + self.data = np.concatenate([self.data, elements.data], axis=0) + offset = self.offsets[-1] + self.lengths[-1] if len(self) > 0 else 0 + self.lengths.extend(elements.lengths) + self.offsets.extend(np.cumsum([offset] + elements.lengths).tolist()[:-1]) + else: + self.data = np.concatenate([self.data] + list(elements), axis=0) + offset = self.offsets[-1] + self.lengths[-1] if len(self) > 0 else 0 + lengths = list(map(len, elements)) + self.lengths.extend(lengths) + self.offsets.extend(np.cumsum([offset] + lengths).tolist()[:-1]) + + def __getitem__(self, idx): + """ Gets element(s) through indexing. + Parameters + ---------- + idx : int, slice or list + Index of the element(s) to get. + Returns + ------- + `ndarray` object(s) + When `idx` is an int, returns a single 2D array. + When `idx` is either a slice or a list, returns a list of 2D arrays.
+ """ + if isinstance(idx, int) or isinstance(idx, np.integer): + return self.data[self.offsets[idx]:self.offsets[idx]+self.lengths[idx]] + + elif type(idx) is slice: + compact_list = CompactList() + compact_list.data = self.data + compact_list.offsets = self.offsets[idx] + compact_list.lengths = self.lengths[idx] + return compact_list + + elif type(idx) is list: + compact_list = CompactList() + compact_list.data = self.data + compact_list.offsets = [self.offsets[i] for i in idx] + compact_list.lengths = [self.lengths[i] for i in idx] + return compact_list + + raise TypeError("Index must be a int or a slice! Not " + str(type(idx))) + + def copy(self): + # We could not only deepcopy the object because when slicing a CompactList it returns + # a view with modified `lengths` and `offsets` but `data` still points to the original data. + compact_list = CompactList() + total_lengths = np.sum(self.lengths) + compact_list.data = np.empty((total_lengths,) + self.data.shape[1:], dtype=self.dtype) + + cur_offset = 0 + for offset, lengths in zip(self.offsets, self.lengths): + compact_list.offsets.append(cur_offset) + compact_list.lengths.append(lengths) + compact_list.data[cur_offset:cur_offset+lengths] = self.data[offset:offset+lengths] + cur_offset += lengths + + return compact_list + + def __iter__(self): + if len(self.lengths) != len(self.offsets): + raise ValueError("CompactList object corrupted: len(self.lengths) != len(self.offsets)") + + for offset, lengths in zip(self.offsets, self.lengths): + yield self.data[offset: offset+lengths] + + def __len__(self): + return len(self.offsets) + + +class Streamline(object): + def __init__(self, points, scalars=None, properties=None): + self.points = points + self.scalars = scalars + self.properties = properties + + def __iter__(self): + return iter(self.points) + + def __len__(self): + return len(self.points) + + class Streamlines(object): ''' Class containing information about streamlines. 
@@ -64,7 +186,18 @@ def points(self): @points.setter def points(self, value): - self._points = value if value else [] + if value is None or len(value) == 0: + self._points = CompactList(preallocate=(0, 3)) + + elif isinstance(value, CompactList): + self._points = value + + elif isinstance(value, list) or isinstance(value, tuple): + self._points = CompactList.from_list(value) + + else: + raise DataError("Unsupported data type: {0}".format(type(value))) + self.header.nb_streamlines = len(self.points) @property @@ -73,9 +206,19 @@ def scalars(self): @scalars.setter def scalars(self, value): - self._scalars = value if value else [] - self.header.nb_scalars_per_point = 0 + if value is None or len(value) == 0: + self._scalars = CompactList() + + elif isinstance(value, CompactList): + self._scalars = value + + elif isinstance(value, list) or isinstance(value, tuple): + self._scalars = CompactList.from_list(value) + else: + raise DataError("Unsupported data type: {0}".format(type(value))) + + self.header.nb_scalars_per_point = 0 if len(self.scalars) > 0 and len(self.scalars[0]) > 0: self.header.nb_scalars_per_point = len(self.scalars[0][0]) @@ -85,14 +228,18 @@ def properties(self): @properties.setter def properties(self, value): - self._properties = value if value else [] + if value is None: + value = [] + + self._properties = np.asarray(value, dtype=np.float32) self.header.nb_properties_per_streamline = 0 if len(self.properties) > 0: self.header.nb_properties_per_streamline = len(self.properties[0]) def __iter__(self): - return zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) + for data in zip_longest(self.points, self.scalars, self.properties, fillvalue=[]): + yield Streamline(*data) def __getitem__(self, idx): pts = self.points[idx] @@ -105,16 +252,16 @@ def __getitem__(self, idx): properties = self.properties[idx] if type(idx) is slice: - return list(zip_longest(pts, scalars, properties, fillvalue=[])) + return Streamlines(pts, scalars, properties) - return pts, scalars, properties + return Streamline(pts, scalars, properties) def __len__(self): return len(self.points) def copy(self): """ Returns a copy of this `Streamlines` object. """ - streamlines = Streamlines(self.points, self.scalars, self.properties) + streamlines = Streamlines(self.points.copy(), self.scalars.copy(), self.properties.copy()) streamlines._header = self.header.copy() return streamlines @@ -280,7 +427,7 @@ def __getitem__(self, idx): def __iter__(self): i = 0 for i, s in enumerate(self._data(), start=1): - yield s + yield Streamline(*s) # To be safe, update information about number of streamlines. self.header.nb_streamlines = i diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index c8a9465265..f7164eca5e 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -48,7 +48,7 @@ def test_streamlines_creation_from_arrays(self): assert_arrays_equal(streamlines.properties, []) # Check if we can iterate through the streamlines. - for points, scalars, props in streamlines: + for streamline in streamlines: pass # Only points @@ -59,7 +59,7 @@ def test_streamlines_creation_from_arrays(self): assert_arrays_equal(streamlines.properties, []) # Check if we can iterate through the streamlines. 
- for points, scalars, props in streamlines: + for streamline in streamlines: pass # Points, scalars and properties @@ -70,9 +70,11 @@ def test_streamlines_creation_from_arrays(self): assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) # Check if we can iterate through the streamlines. - for point, scalar, prop in streamlines: + for streamline in streamlines: pass + #streamlines = Streamlines(self.points, scalars) + def test_streamlines_getter(self): # Streamlines with only points streamlines = Streamlines(points=self.points) @@ -80,27 +82,24 @@ def test_streamlines_getter(self): selected_streamlines = streamlines[::2] assert_equal(len(selected_streamlines), (len(self.points)+1)//2) - points, scalars, properties = zip(*selected_streamlines) - assert_arrays_equal(points, self.points[::2]) - assert_equal(sum(map(len, scalars)), 0) - assert_equal(sum(map(len, properties)), 0) + assert_arrays_equal(selected_streamlines.points, self.points[::2]) + assert_equal(sum(map(len, selected_streamlines.scalars)), 0) + assert_equal(sum(map(len, selected_streamlines.properties)), 0) # Streamlines with points, scalars and properties streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) # Retrieve streamlines by their index - for i, (points, scalars, props) in enumerate(streamlines): - points_i, scalars_i, props_i = streamlines[i] - assert_array_equal(points_i, points) - assert_array_equal(scalars_i, scalars) - assert_array_equal(props_i, props) + for i, streamline in enumerate(streamlines): + assert_array_equal(streamline.points, streamlines[i].points) + assert_array_equal(streamline.scalars, streamlines[i].scalars) + assert_array_equal(streamline.properties, streamlines[i].properties) # Use slicing r_streamlines = streamlines[::-1] - r_points, r_scalars, r_props = zip(*r_streamlines) - assert_arrays_equal(r_points, self.points[::-1]) - assert_arrays_equal(r_scalars, self.colors[::-1]) - assert_arrays_equal(r_props, self.mean_curvature_torsion[::-1]) + assert_arrays_equal(r_streamlines.points, self.points[::-1]) + assert_arrays_equal(r_streamlines.scalars, self.colors[::-1]) + assert_arrays_equal(r_streamlines.properties, self.mean_curvature_torsion[::-1]) def test_streamlines_creation_from_coroutines(self): # Points, scalars and properties @@ -199,7 +198,7 @@ def test_lazy_streamlines_creation(self): assert_arrays_equal(streamlines.properties, []) # Check if we can iterate through the streamlines. - for point, scalar, prop in streamlines: + for streamline in streamlines: pass # Points, scalars and properties @@ -230,7 +229,7 @@ def test_lazy_streamlines_creation(self): assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) # Check if we can iterate through the streamlines. - for point, scalar, prop in streamlines: + for streamline in streamlines: pass def test_lazy_streamlines_indexing(self): @@ -290,7 +289,7 @@ def test_lazy_streamlines_len(self): # Once we iterated through the streamlines, we know the length. 
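+        # (Iterating a LazyStreamlines updates `header.nb_streamlines` as a
+        # side effect of `__iter__`, which is what this test relies on.)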
streamlines = LazyStreamlines(points, scalars, properties) assert_true(streamlines.header.nb_streamlines is None) - for s in streamlines: + for streamline in streamlines: pass assert_equal(streamlines.header.nb_streamlines, len(self.points)) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 893331cb77..e6ab8cab17 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -22,7 +22,7 @@ def isiterable(streamlines): try: - for point, scalar, prop in streamlines: + for _ in streamlines: pass except: return False diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 1a1aa2ce05..e29b288dc5 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -20,7 +20,7 @@ def isiterable(streamlines): try: - for point, scalar, prop in streamlines: + for streamline in streamlines: pass except: return False @@ -201,29 +201,33 @@ def test_write_erroneous_file(self): streamlines = Streamlines(self.points, scalars) assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) - # Inconsistent number of scalars between points - scalars = [[(1, 0, 0)]*1, - [(0, 1, 0), (0, 1)], - [(0, 0, 1)]*5] - - streamlines = Streamlines(self.points, scalars) - assert_raises(ValueError, TrkFile.save, streamlines, BytesIO()) - - # Inconsistent number of scalars between streamlines - scalars = [[(1, 0, 0)]*1, - [(0, 1)]*2, - [(0, 0, 1)]*5] - - streamlines = Streamlines(self.points, scalars) - assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) - - # Inconsistent number of properties - properties = [np.array([1.11, 1.22], dtype="f4"), - np.array([2.11], dtype="f4"), - np.array([3.11, 3.22], dtype="f4")] - streamlines = Streamlines(self.points, properties=properties) - assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) - + # # Unit test moved to test_base_format.py + # # Inconsistent number of scalars between points + # scalars = [[(1, 0, 0)]*1, + # [(0, 1, 0), (0, 1)], + # [(0, 0, 1)]*5] + + #streamlines = Streamlines(self.points, scalars) + #assert_raises(ValueError, TrkFile.save, streamlines, BytesIO()) + + # # Unit test moved to test_base_format.py + # # Inconsistent number of scalars between streamlines + # scalars = [[(1, 0, 0)]*1, + # [(0, 1)]*2, + # [(0, 0, 1)]*5] + + # streamlines = Streamlines(self.points, scalars) + # assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + + # # Unit test moved to test_base_format.py + # # Inconsistent number of properties + # properties = [np.array([1.11, 1.22], dtype="f4"), + # np.array([2.11], dtype="f4"), + # np.array([3.11, 3.22], dtype="f4")] + # streamlines = Streamlines(self.points, properties=properties) + # assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + + # # Unit test moved to test_base_format.py # No properties for every streamlines properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11, 2.22], dtype="f4")] diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 2a9ea9d29d..70df0385d0 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -12,6 +12,7 @@ from nibabel.openers import Opener from nibabel.volumeutils import (native_code, swapped_code) +from nibabel.streamlines.base_format import CompactList from nibabel.streamlines.base_format import StreamlinesFile from nibabel.streamlines.base_format import DataError, HeaderError, HeaderWarning from nibabel.streamlines.base_format 
import Streamlines, LazyStreamlines @@ -139,11 +140,23 @@ def __init__(self, fileobj): # Keep the file position where the data begin. self.offset_data = f.tell() + if f.name is not None and self.header[Field.NB_STREAMLINES] > 0: + filesize = os.path.getsize(f.name) - self.offset_data + # Remove properties + filesize -= self.header[Field.NB_STREAMLINES] * self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * 4. + # Remove the points count at the beginning of each streamline. + filesize -= self.header[Field.NB_STREAMLINES] * 4. + # Get nb points. + nb_points = filesize / ((3 + self.header[Field.NB_SCALARS_PER_POINT]) * 4.) + self.header[Field.NB_POINTS] = int(nb_points) + def __iter__(self): i4_dtype = np.dtype(self.endianness + "i4") f4_dtype = np.dtype(self.endianness + "f4") + #from io import BufferedReader with Opener(self.fileobj) as f: + #f = BufferedReader(f.fobj) start_position = f.tell() nb_pts_and_scalars = int(3 + self.header[Field.NB_SCALARS_PER_POINT]) @@ -257,13 +270,13 @@ def write(self, streamlines): i4_dtype = np.dtype("i4") f4_dtype = np.dtype("f4") - for points, scalars, properties in streamlines: - if len(scalars) > 0 and len(scalars) != len(points): + for s in streamlines: + if len(s.scalars) > 0 and len(s.scalars) != len(s.points): raise DataError("Missing scalars for some points!") - points = np.array(points, dtype=f4_dtype) - scalars = np.array(scalars, dtype=f4_dtype).reshape((len(points), -1)) - properties = np.array(properties, dtype=f4_dtype) + points = np.asarray(s.points, dtype=f4_dtype) + scalars = np.asarray(s.scalars, dtype=f4_dtype).reshape((len(points), -1)) + properties = np.asarray(s.properties, dtype=f4_dtype) # TRK's streamlines need to be in 'voxelmm' space points = points * self.header[Field.VOXEL_SIZES] @@ -413,6 +426,45 @@ def load(fileobj, ref=None, lazy_load=False): streamlines.scalars = lambda: [] if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: streamlines.properties = lambda: [] + elif Field.NB_POINTS in trk_reader.header: + # 'count' field is provided, we can avoid creating list of numpy + # arrays (more memory efficient). 
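+            # Worked example (hypothetical numbers): 1,000 streamlines of 100
+            # points each with no scalars and no properties occupy
+            # 1,000 * (4 + 100 * 3 * 4) bytes after the header; removing the
+            # 4-byte point counts and dividing by 12 bytes per point gives
+            # nb_points = 100,000, so the point and scalar buffers can be
+            # preallocated in one shot.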
+ + nb_streamlines = trk_reader.header[Field.NB_STREAMLINES] + nb_points = trk_reader.header[Field.NB_POINTS] + + points = CompactList(preallocate=(nb_points, 3)) + scalars = CompactList(preallocate=(nb_points, trk_reader.header[Field.NB_SCALARS_PER_POINT])) + properties = np.empty((nb_streamlines, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]), + dtype=np.float32) + + offset = 0 + offsets = [] + lengths = [] + for i, (pts, scals, props) in enumerate(data()): + offsets.append(offset) + lengths.append(len(pts)) + try: + points.data[offset:offset+len(pts)] = pts + except: + from ipdb import set_trace as dbg + dbg() + scalars.data[offset:offset+len(scals)] = scals + properties[i] = props + offset += len(pts) + + points.offsets = offsets + scalars.offsets = offsets + points.lengths = lengths + scalars.lengths = lengths + streamlines = Streamlines(points, scalars, properties) + + # Overwrite scalars and properties if there is none + if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: + streamlines.scalars = [] + if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: + streamlines.properties = [] + else: streamlines = Streamlines(*zip(*data())) From b8dcee992477fa2c4ed9dcff5ef27d73c584d293 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 30 Oct 2015 23:08:55 -0400 Subject: [PATCH 12/54] Old unit tests are all passing --- nibabel/__init__.py | 2 +- nibabel/benchmarks/bench_streamlines.py | 12 +- nibabel/streamlines/__init__.py | 2 +- nibabel/streamlines/base_format.py | 420 +++++++++++------- nibabel/streamlines/header.py | 4 +- nibabel/streamlines/tests/test_base_format.py | 305 ++++++++++--- nibabel/streamlines/tests/test_header.py | 4 +- nibabel/streamlines/tests/test_streamlines.py | 16 +- nibabel/streamlines/tests/test_trk.py | 24 +- nibabel/streamlines/trk.py | 168 +++---- 10 files changed, 633 insertions(+), 324 deletions(-) diff --git a/nibabel/__init__.py b/nibabel/__init__.py index 3bf1fa4cd3..d40a388487 100644 --- a/nibabel/__init__.py +++ b/nibabel/__init__.py @@ -63,7 +63,7 @@ apply_orientation, aff2axcodes) from .imageclasses import class_map, ext_map, all_image_classes from . import trackvis -from .streamlines import Streamlines +from .streamlines import Tractogram from . 
import mriutils # be friendly on systems with ancient numpy -- no tests, but at least diff --git a/nibabel/benchmarks/bench_streamlines.py b/nibabel/benchmarks/bench_streamlines.py index 4eba60be7d..3d0da178c1 100644 --- a/nibabel/benchmarks/bench_streamlines.py +++ b/nibabel/benchmarks/bench_streamlines.py @@ -45,14 +45,14 @@ def bench_load_trk(): streamlines = Streamlines(points) TrkFile.save(streamlines, trk_file) - from pycallgraph import PyCallGraph - from pycallgraph.output import GraphvizOutput + # from pycallgraph import PyCallGraph + # from pycallgraph.output import GraphvizOutput - with PyCallGraph(output=GraphvizOutput()): - #nib.streamlines.load(trk_file, ref=None, lazy_load=False) + # with PyCallGraph(output=GraphvizOutput()): + # #nib.streamlines.load(trk_file, ref=None, lazy_load=False) - mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, ref=None, lazy_load=False)', repeat) - print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) + mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, ref=None, lazy_load=False)', repeat) + print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file, points_space="voxel")', repeat) print("Old: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old)) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 7f3d7dcdf2..e23757013f 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -1,5 +1,5 @@ from .header import Field -from .base_format import Streamlines, LazyStreamlines +from .base_format import Tractogram, LazyTractogram from nibabel.streamlines.trk import TrkFile #from nibabel.streamlines.tck import TckFile diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 8c6f3a9973..95fc2669b5 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -1,10 +1,11 @@ +import itertools import numpy as np from warnings import warn from nibabel.externals.six.moves import zip_longest from nibabel.affines import apply_affine -from .header import StreamlinesHeader +from .header import TractogramHeader from .utils import pop @@ -25,119 +26,198 @@ class DataError(Exception): class CompactList(object): - def __init__(self, preallocate=(0,), dtype=np.float32): - self.dtype = dtype - self.data = np.empty(preallocate, dtype=dtype) - self.offsets = [] - self.lengths = [] - - @classmethod - def from_list(cls, elements): - """ Fast way to create a `Streamlines` object from some streamlines. + """ Class for compacting list of ndarrays with matching shape except for + the first dimension. + """ + def __init__(self, iterable=None): + """ Parameters ---------- - elements : list - List of 2D ndarrays of same shape except on the first dimension. + iterable : iterable (optional) + If specified, create a ``CompactList`` object initialized from + iterable's items. Otherwise, create an empty ``CompactList``. """ - if len(elements) == 0: - return cls() + # Create new empty `CompactList` object. + self._data = None + self._offsets = [] + self._lengths = [] + + if iterable is not None: + # Initialize the `CompactList` object from iterable's item. 
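+            # (Buffered growth: `_data` starts at BUFFER_SIZE rows and is
+            # resized by at least BUFFER_SIZE extra rows whenever an element
+            # would overflow it, so reallocations are amortized instead of
+            # happening once per appended element.)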
+ BUFFER_SIZE = 1000 + + offset = 0 + for i, e in enumerate(iterable): + e = np.asarray(e) + if i == 0: + self._data = np.empty((BUFFER_SIZE,) + e.shape[1:], dtype=e.dtype) + + end = offset + len(e) + if end >= len(self._data): + # Resize is needed (at least `len(e)` items will be added). + self._data.resize((len(self._data) + len(e)+BUFFER_SIZE,) + self.shape) + + self._offsets.append(offset) + self._lengths.append(len(e)) + self._data[offset:offset+len(e)] = e + offset += len(e) + + # Clear unused memory. + if self._data is not None: + self._data.resize((offset,) + self.shape) - first_element = np.asarray(elements[0]) - s = cls(preallocate=(0,) + first_element.shape[1:], dtype=first_element.dtype) - s.extend(elements) - return s + @property + def shape(self): + """ Returns the matching shape of the elements in this compact list. """ + if self._data is None: + return None + + return self._data.shape[1:] def append(self, element): - """ + """ Appends `element` to this compact list. + Parameters ---------- - element : 2D ndarrays of `element.shape[1:] == self.shape` - Element to add. - Note - ---- + element : ndarray + Element to append. The shape must match already inserted elements + shape except for the first dimension. + + Notes + ----- If you need to add multiple elements you should consider - `CompactList.from_list` or `CompactList.extend`. + `CompactList.extend`. """ - self.offsets.append(len(self.data)) - self.lengths.append(len(element)) - self.data = np.append(self.data, element, axis=0) + if self._data is None: + self._data = np.asarray(element).copy() + self._offsets.append(0) + self._lengths.append(len(element)) + return + + if element.shape[1:] != self.shape: + raise ValueError("All dimensions, except the first one, must match exactly") + + self._offsets.append(len(self._data)) + self._lengths.append(len(element)) + self._data = np.append(self._data, element, axis=0) def extend(self, elements): + """ Appends all `elements` to this compact list. + + Parameters + ---------- + element : list of ndarrays, ``CompactList`` object + Elements to append. The shape must match already inserted elements + shape except for the first dimension. + """ if isinstance(elements, CompactList): - self.data = np.concatenate([self.data, elements.data], axis=0) - offset = self.offsets[-1] + self.lengths[-1] if len(self) > 0 else 0 - self.lengths.extend(elements.lengths) - self.offsets.extend(np.cumsum([offset] + elements.lengths).tolist()[:-1]) + self._data = np.concatenate([self._data, elements._data], axis=0) + offset = self._offsets[-1] + self._lengths[-1] if len(self) > 0 else 0 + self._lengths.extend(elements._lengths) + self._offsets.extend(np.cumsum([offset] + elements._lengths).tolist()[:-1]) else: - self.data = np.concatenate([self.data] + list(elements), axis=0) - offset = self.offsets[-1] + self.lengths[-1] if len(self) > 0 else 0 + self._data = np.concatenate([self._data] + list(elements), axis=0) + offset = self._offsets[-1] + self._lengths[-1] if len(self) > 0 else 0 lengths = map(len, elements) - self.lengths.extend(lengths) - self.offsets.extend(np.cumsum([offset] + lengths).tolist()[:-1]) + self._lengths.extend(lengths) + self._offsets.extend(np.cumsum([offset] + lengths).tolist()[:-1]) + + def copy(self): + """ Creates a copy of this ``CompactList`` object. """ + # We cannot just deepcopy this object since we don't know if it has been created + # using slicing. 
If it is the case, `self.data` probably contains more data than necessary + # so we copy only elements according to `self._offsets`. + compact_list = CompactList() + total_lengths = np.sum(self._lengths) + compact_list._data = np.empty((total_lengths,) + self._data.shape[1:], dtype=self._data.dtype) + + cur_offset = 0 + for offset, lengths in zip(self._offsets, self._lengths): + compact_list._offsets.append(cur_offset) + compact_list._lengths.append(lengths) + compact_list._data[cur_offset:cur_offset+lengths] = self._data[offset:offset+lengths] + cur_offset += lengths + + return compact_list def __getitem__(self, idx): """ Gets element(s) through indexing. + Parameters ---------- idx : int, slice or list Index of the element(s) to get. + Returns ------- - `ndarray` object(s) - When `idx` is a int, returns a single 2D array. - When `idx` is either a slice or a list, returns a list of 2D arrays. + ndarray object(s) + When `idx` is a int, returns a single ndarray. + When `idx` is either a slice or a list, returns a list of ndarrays. """ if isinstance(idx, int) or isinstance(idx, np.integer): - return self.data[self.offsets[idx]:self.offsets[idx]+self.lengths[idx]] + return self._data[self._offsets[idx]:self._offsets[idx]+self._lengths[idx]] elif type(idx) is slice: + # TODO: Should we have a CompactListView class that would be + # returned when slicing? compact_list = CompactList() - compact_list.data = self.data - compact_list.offsets = self.offsets[idx] - compact_list.lengths = self.lengths[idx] + compact_list._data = self._data + compact_list._offsets = self._offsets[idx] + compact_list._lengths = self._lengths[idx] return compact_list elif type(idx) is list: + # TODO: Should we have a CompactListView class that would be + # returned when doing advance indexing? compact_list = CompactList() - compact_list.data = self.data - compact_list.offsets = [self.offsets[i] for i in idx] - compact_list.lengths = [self.lengths[i] for i in idx] + compact_list._data = self._data + compact_list._offsets = [self._offsets[i] for i in idx] + compact_list._lengths = [self._lengths[i] for i in idx] return compact_list raise TypeError("Index must be a int or a slice! Not " + str(type(idx))) - def copy(self): - # We could not only deepcopy the object because when slicing a CompactList it returns - # a view with modified `lengths` and `offsets` but `data` still points to the original data. 
- compact_list = CompactList() - total_lengths = np.sum(self.lengths) - compact_list.data = np.empty((total_lengths,) + self.data.shape[1:], dtype=self.dtype) - - cur_offset = 0 - for offset, lengths in zip(self.offsets, self.lengths): - compact_list.offsets.append(cur_offset) - compact_list.lengths.append(lengths) - compact_list.data[cur_offset:cur_offset+lengths] = self.data[offset:offset+lengths] - cur_offset += lengths - - return compact_list - def __iter__(self): - if len(self.lengths) != len(self.offsets): - raise ValueError("CompactList object corrupted: len(self.lengths) != len(self.offsets)") + if len(self._lengths) != len(self._offsets): + raise ValueError("CompactList object corrupted: len(self._lengths) != len(self._offsets)") - for offset, lengths in zip(self.offsets, self.lengths): - yield self.data[offset: offset+lengths] + for offset, lengths in zip(self._offsets, self._lengths): + yield self._data[offset: offset+lengths] def __len__(self): - return len(self.offsets) + return len(self._offsets) + def __repr__(self): + return repr(list(self)) -class Streamline(object): +class TractogramItem(object): + ''' Class containing information about one streamline. + + ``TractogramItem`` objects have three main properties: `points`, `scalars` + and ``properties``. + + Parameters + ---------- + points : ndarray of shape (N, 3) + Points of this streamline represented as an ndarray of shape (N, 3) + where N is the number of points. + + scalars : ndarray of shape (N, M) + Scalars associated with each point of this streamline and represented + as an ndarray of shape (N, M) where N is the number of points and + M is the number of scalars (excluding the three coordinates). + + properties : ndarray of shape (P,) + Properties associated with this streamline and represented as an + ndarray of shape (P,) where P is the number of properties. + ''' def __init__(self, points, scalars=None, properties=None): - self.points = points - self.scalars = scalars - self.properties = properties + #if scalars is not None and len(points) != len(scalars): + # raise ValueError("First dimension of points and scalars must match.") + + self.points = np.asarray(points) + self.scalars = np.asarray([] if scalars is None else scalars) + self.properties = np.asarray([] if properties is None else properties) def __iter__(self): return iter(self.points) @@ -146,11 +226,11 @@ def __len__(self): return len(self.points) -class Streamlines(object): +class Tractogram(object): ''' Class containing information about streamlines. - Streamlines objects have three main properties: ``points``, ``scalars`` - and ``properties``. Streamlines objects can be iterate over producing + Tractogram objects have three main properties: ``points``, ``scalars`` + and ``properties``. Tractogram objects can be iterate over producing tuple of ``points``, ``scalars`` and ``properties`` for each streamline. Parameters @@ -171,11 +251,71 @@ class Streamlines(object): associated to each streamline. 
''' def __init__(self, points=None, scalars=None, properties=None): - self._header = StreamlinesHeader() + self._header = TractogramHeader() self.points = points self.scalars = scalars self.properties = properties + @classmethod + def create_from_generator(cls, gen): + BUFFER_SIZE = 1000 + + points = CompactList() + scalars = CompactList() + properties = np.array([]) + + gen = iter(gen) + try: + first_element = next(gen) + gen = itertools.chain([first_element], gen) + except StopIteration: + return cls(points, scalars, properties) + + # Allocated some buffer memory. + pts = np.asarray(first_element[0]) + scals = np.asarray(first_element[1]) + props = np.asarray(first_element[2]) + + points._data = np.empty((BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype) + scalars._data = np.empty((BUFFER_SIZE, scals.shape[1]), dtype=scals.dtype) + properties = np.empty((BUFFER_SIZE, props.shape[0]), dtype=props.dtype) + + offset = 0 + for i, (pts, scals, props) in enumerate(gen): + pts = np.asarray(pts) + scals = np.asarray(scals) + props = np.asarray(props) + + end = offset + len(pts) + if end >= len(points._data): + # Resize is needed (at least `len(pts)` items will be added). + points._data.resize((len(points._data) + len(pts)+BUFFER_SIZE, pts.shape[1])) + scalars._data.resize((len(scalars._data) + len(scalars)+BUFFER_SIZE, scals.shape[1])) + + points._offsets.append(offset) + points._lengths.append(len(pts)) + points._data[offset:offset+len(pts)] = pts + scalars._data[offset:offset+len(scals)] = scals + + offset += len(pts) + + if i >= len(properties): + properties.resize((len(properties) + BUFFER_SIZE, props.shape[0])) + + properties[i] = props + + # Share offsets and lengths between points and scalars. + scalars._offsets = points._offsets + scalars._lengths = points._lengths + + # Clear unused memory. 
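+        # (The resizes below shrink the point and scalar buffers to the
+        # `offset` rows actually written, and `properties` to the number of
+        # streamlines read, releasing the spare capacity left over from the
+        # buffered growth above.)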
+ points._data.resize((offset, pts.shape[1])) + scalars._data.resize((offset, scals.shape[1])) + properties.resize((i+1, props.shape[0])) + + return cls(points, scalars, properties) + + @property def header(self): return self._header @@ -186,17 +326,9 @@ def points(self): @points.setter def points(self, value): - if value is None or len(value) == 0: - self._points = CompactList(preallocate=(0, 3)) - - elif isinstance(value, CompactList): - self._points = value - - elif isinstance(value, list) or isinstance(value, tuple): - self._points = CompactList.from_list(value) - - else: - raise DataError("Unsupported data type: {0}".format(type(value))) + self._points = value + if not isinstance(value, CompactList): + self._points = CompactList(value) self.header.nb_streamlines = len(self.points) @@ -206,17 +338,9 @@ def scalars(self): @scalars.setter def scalars(self, value): - if value is None or len(value) == 0: - self._scalars = CompactList() - - elif isinstance(value, CompactList): - self._scalars = value - - elif isinstance(value, list) or isinstance(value, tuple): - self._scalars = CompactList.from_list(value) - - else: - raise DataError("Unsupported data type: {0}".format(type(value))) + self._scalars = value + if not isinstance(value, CompactList): + self._scalars = CompactList(value) self.header.nb_scalars_per_point = 0 if len(self.scalars) > 0 and len(self.scalars[0]) > 0: @@ -228,18 +352,17 @@ def properties(self): @properties.setter def properties(self, value): + self._properties = np.asarray(value) if value is None: - value = [] + self._properties = np.empty((len(self), 0), dtype=np.float32) - self._properties = np.asarray(value, dtype=np.float32) self.header.nb_properties_per_streamline = 0 - if len(self.properties) > 0: self.header.nb_properties_per_streamline = len(self.properties[0]) def __iter__(self): - for data in zip_longest(self.points, self.scalars, self.properties, fillvalue=[]): - yield Streamline(*data) + for data in zip_longest(self.points, self.scalars, self.properties, fillvalue=None): + yield TractogramItem(*data) def __getitem__(self, idx): pts = self.points[idx] @@ -252,70 +375,43 @@ def __getitem__(self, idx): properties = self.properties[idx] if type(idx) is slice: - return Streamlines(pts, scalars, properties) + return Tractogram(pts, scalars, properties) - return Streamline(pts, scalars, properties) + return TractogramItem(pts, scalars, properties) def __len__(self): return len(self.points) def copy(self): - """ Returns a copy of this `Streamlines` object. """ - streamlines = Streamlines(self.points.copy(), self.scalars.copy(), self.properties.copy()) + """ Returns a copy of this `Tractogram` object. """ + streamlines = Tractogram(self.points.copy(), self.scalars.copy(), self.properties.copy()) streamlines._header = self.header.copy() return streamlines - def transform(self, affine, lazy=False): + def apply_affine(self, affine): """ Applies an affine transformation on the points of each streamline. + This is performed in-place. + Parameters ---------- affine : 2D array (4,4) Transformation that will be applied on each streamline. - lazy : bool (optional) - If true output will be a generator of arrays instead of a list. - - Returns - ------- - streamlines - If `lazy` is true, a `LazyStreamlines` object is returned, - otherwise a `Streamlines` object is returned. In both case, - streamlines are in a space defined by `affine`. 
""" - points = lambda: (apply_affine(affine, pts) for pts in self.points) - if not lazy: - points = list(points()) - - streamlines = self.copy() - streamlines.points = points - streamlines.header.to_world_space = np.dot(streamlines.header.to_world_space, - np.linalg.inv(affine)) + if len(self.points) == 0: + return - return streamlines - - def to_world_space(self, lazy=False): - """ Sends the streamlines back into world space. - - Parameters - ---------- - lazy : bool (optional) - If true output will be a generator of arrays instead of a list. - - Returns - ------- - streamlines - If `lazy` is true, a `LazyStreamlines` object is returned, - otherwise a `Streamlines` object is returned. In both case, - streamlines are in world space. - """ - return self.transform(self.header.to_world_space, lazy) + BUFFER_SIZE = 10000 + for i in range(0, len(self.points._data), BUFFER_SIZE): + pts = self.points._data[i:i+BUFFER_SIZE] + self.points._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts) -class LazyStreamlines(Streamlines): +class LazyTractogram(Tractogram): ''' Class containing information about streamlines. - Streamlines objects have four main properties: ``header``, ``points``, - ``scalars`` and ``properties``. Streamlines objects are iterable and + Tractogram objects have four main properties: ``header``, ``points``, + ``scalars`` and ``properties``. Tractogram objects are iterable and produce tuple of ``points``, ``scalars`` and ``properties`` for each streamline. @@ -347,7 +443,7 @@ class LazyStreamlines(Streamlines): values as ``points``. ''' def __init__(self, points_func=lambda:[], scalars_func=lambda: [], properties_func=lambda: [], getitem_func=None): - super(LazyStreamlines, self).__init__(points_func, scalars_func, properties_func) + super(LazyTractogram, self).__init__(points_func, scalars_func, properties_func) self._data = lambda: zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) self._getitem = getitem_func @@ -420,14 +516,14 @@ def properties(self, value): def __getitem__(self, idx): if self._getitem is None: - raise AttributeError('`LazyStreamlines` does not support indexing.') + raise AttributeError('`LazyTractogram` does not support indexing.') return self._getitem(idx) def __iter__(self): i = 0 for i, s in enumerate(self._data(), start=1): - yield Streamline(*s) + yield TractogramItem(*s) # To be safe, update information about number of streamlines. self.header.nb_streamlines = i @@ -440,14 +536,14 @@ def __len__(self): " streamlines, you might want to set it beforehand via" " `self.header.nb_streamlines`." " Note this will consume any generators used to create this" - " `LazyStreamlines` object.", UsageWarning) + " `LazyTractogram` object.", UsageWarning) return sum(1 for _ in self) return self.header.nb_streamlines def copy(self): - """ Returns a copy of this `LazyStreamlines` object. """ - streamlines = LazyStreamlines(self._points, self._scalars, self._properties) + """ Returns a copy of this `LazyTractogram` object. """ + streamlines = LazyTractogram(self._points, self._scalars, self._properties) streamlines._header = self.header.copy() return streamlines @@ -461,23 +557,23 @@ def transform(self, affine): Returns ------- - streamlines : `LazyStreamlines` object - Streamlines living in a space defined by `affine`. + streamlines : `LazyTractogram` object + Tractogram living in a space defined by `affine`. 
""" - return super(LazyStreamlines, self).transform(affine, lazy=True) + return super(LazyTractogram, self).transform(affine, lazy=True) def to_world_space(self): """ Sends the streamlines back into world space. Returns ------- - streamlines : `LazyStreamlines` object - Streamlines living in world space. + streamlines : `LazyTractogram` object + Tractogram living in world space. """ - return super(LazyStreamlines, self).to_world_space(lazy=True) + return super(LazyTractogram, self).to_world_space(lazy=True) -class StreamlinesFile: +class TractogramFile: ''' Convenience class to encapsulate streamlines file format. ''' @classmethod @@ -533,9 +629,9 @@ def load(fileobj, ref, lazy_load=True): Returns ------- - streamlines : Streamlines object + streamlines : Tractogram object Returns an object containing streamlines' data and header - information. See 'nibabel.Streamlines'. + information. See 'nibabel.Tractogram'. ''' raise NotImplementedError() @@ -545,9 +641,9 @@ def save(streamlines, fileobj, ref=None): Parameters ---------- - streamlines : Streamlines object + streamlines : Tractogram object Object containing streamlines' data and header information. - See 'nibabel.Streamlines'. + See 'nibabel.Tractogram'. fileobj : string or file-like object If string, a filename; otherwise an open file-like object @@ -577,7 +673,7 @@ def pretty_print(streamlines): raise NotImplementedError() -# class DynamicStreamlineFile(StreamlinesFile): +# class DynamicTractogramFile(TractogramFile): # ''' Convenience class to encapsulate streamlines file format # that supports appending streamlines to an existing file. # ''' diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py index 7a7ec63b3c..706afd348a 100644 --- a/nibabel/streamlines/header.py +++ b/nibabel/streamlines/header.py @@ -24,7 +24,7 @@ class Field: ENDIAN = "endian" -class StreamlinesHeader(object): +class TractogramHeader(object): def __init__(self): self._nb_streamlines = None self._nb_scalars_per_point = None @@ -90,7 +90,7 @@ def extra(self, value): self._extra = OrderedDict(value) def copy(self): - header = StreamlinesHeader() + header = TractogramHeader() header._nb_streamlines = self.nb_streamlines header.nb_scalars_per_point = self.nb_scalars_per_point header.nb_properties_per_streamline = self.nb_properties_per_streamline diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index f7164eca5e..b699066c1c 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -7,16 +7,232 @@ from nibabel.testing import suppress_warnings, clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true from numpy.testing import assert_array_equal, assert_array_almost_equal -from nibabel.externals.six.moves import zip +from nibabel.externals.six.moves import zip, zip_longest from .. 
import base_format -from ..base_format import Streamlines, LazyStreamlines +from ..base_format import CompactList +from ..base_format import TractogramItem, Tractogram, LazyTractogram from ..base_format import UsageWarning DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') -class TestStreamlines(unittest.TestCase): +class TestCompactList(unittest.TestCase): + + def setUp(self): + rng = np.random.RandomState(42) + self.data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + self.lengths = map(len, self.data) + self.clist = CompactList(self.data) + + def test_creating_empty_compactlist(self): + clist = CompactList() + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + def test_creating_compactlist_from_list(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = map(len, data) + + clist = CompactList(data) + assert_equal(len(clist), len(data)) + assert_equal(len(clist._offsets), len(data)) + assert_equal(len(clist._lengths), len(data)) + assert_equal(clist._data.shape[0], sum(lengths)) + assert_equal(clist._data.shape[1], 3) + assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist._lengths, lengths) + assert_equal(clist.shape, data[0].shape[1:]) + + # Empty list + clist = CompactList([]) + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + def test_creating_compactlist_from_generator(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = map(len, data) + + gen = (e for e in data) + clist = CompactList(gen) + assert_equal(len(clist), len(data)) + assert_equal(len(clist._offsets), len(data)) + assert_equal(len(clist._lengths), len(data)) + assert_equal(clist._data.shape[0], sum(lengths)) + assert_equal(clist._data.shape[1], 3) + assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist._lengths, lengths) + assert_equal(clist.shape, data[0].shape[1:]) + + # Already consumed generator + clist = CompactList(gen) + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + def test_creating_compactlist_from_compact_list(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = map(len, data) + + clist = CompactList(data) + clist2 = CompactList(clist) + assert_equal(len(clist2), len(data)) + assert_equal(len(clist2._offsets), len(data)) + assert_equal(len(clist2._lengths), len(data)) + assert_equal(clist2._data.shape[0], sum(lengths)) + assert_equal(clist2._data.shape[1], 3) + assert_equal(clist2._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist2._lengths, lengths) + assert_equal(clist2.shape, data[0].shape[1:]) + + def test_compactlist_iter(self): + for e, d in zip(self.clist, self.data): + assert_array_equal(e, d) + + def test_compactlist_copy(self): + clist = self.clist.copy() + assert_array_equal(clist._data, self.clist._data) + assert_true(clist._data is not self.clist._data) + assert_array_equal(clist._offsets, self.clist._offsets) + assert_true(clist._offsets is not self.clist._offsets) + assert_array_equal(clist._lengths, 
self.clist._lengths) + assert_true(clist._lengths is not self.clist._lengths) + + assert_equal(clist.shape, self.clist.shape) + + # When taking a copy of a `CompactList` generated by slicing. + # Only needed data should be kept. + clist = self.clist[::2].copy() + + assert_true(clist._data.shape[0] < self.clist._data.shape[0]) + assert_true(len(clist) < len(self.clist)) + assert_true(clist._data is not self.clist._data) + + def test_compactlist_append(self): + # Maybe not necessary if `self.setUp` is always called before a + # test method, anyways create a copy just in case. + clist = self.clist.copy() + + rng = np.random.RandomState(1234) + element = rng.rand(rng.randint(10, 50), *self.clist.shape) + clist.append(element) + assert_equal(len(clist), len(self.clist)+1) + assert_equal(clist._offsets[-1], len(self.clist._data)) + assert_equal(clist._lengths[-1], len(element)) + assert_array_equal(clist._data[-len(element):], element) + + # Append with different shape. + element = rng.rand(rng.randint(10, 50), 42) + assert_raises(ValueError, clist.append, element) + + # Append to an empty CompactList. + clist = CompactList() + rng = np.random.RandomState(1234) + shape = (2, 3, 4) + element = rng.rand(rng.randint(10, 50), *shape) + clist.append(element) + + assert_equal(len(clist), 1) + assert_equal(clist._offsets[-1], 0) + assert_equal(clist._lengths[-1], len(element)) + assert_array_equal(clist._data, element) + assert_equal(clist.shape, shape) + + def test_compactlist_extend(self): + # Maybe not necessary if `self.setUp` is always called before a + # test method, anyways create a copy just in case. + clist = self.clist.copy() + + rng = np.random.RandomState(1234) + shape = self.clist.shape + new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(5)] + lengths = map(len, new_data) + clist.extend(new_data) + assert_equal(len(clist), len(self.clist)+len(new_data)) + assert_array_equal(clist._offsets[-len(new_data):], + len(self.clist._data) + np.cumsum([0] + lengths[:-1])) + + assert_equal(clist._lengths[-len(new_data):], lengths) + assert_array_equal(clist._data[-sum(lengths):], + np.concatenate(new_data, axis=0)) + + # Extend with another `CompactList` object. + clist = self.clist.copy() + new_data = CompactList(new_data) + clist.extend(new_data) + assert_equal(len(clist), len(self.clist)+len(new_data)) + assert_array_equal(clist._offsets[-len(new_data):], + len(self.clist._data) + np.cumsum([0] + lengths[:-1])) + + assert_equal(clist._lengths[-len(new_data):], lengths) + assert_array_equal(clist._data[-sum(lengths):], new_data._data) + + def test_compactlist_getitem(self): + # Get one item + for i, e in enumerate(self.clist): + assert_array_equal(self.clist[i], e) + + # Get multiple items (this will create a view). + clist_view = self.clist[range(len(self.clist))] + assert_true(clist_view is not self.clist) + assert_true(clist_view._data is self.clist._data) + assert_true(clist_view._offsets is not self.clist._offsets) + assert_true(clist_view._lengths is not self.clist._lengths) + assert_array_equal(clist_view._offsets, self.clist._offsets) + assert_array_equal(clist_view._lengths, self.clist._lengths) + for e1, e2 in zip_longest(clist_view, self.clist): + assert_array_equal(e1, e2) + + # Get slice (this will create a view). 
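+        # (A sliced view shares `_data` with the original; only `_offsets`
+        # and `_lengths` are new lists, so the view's memory cost grows with
+        # the number of selected elements, not with their point counts.)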
+ clist_view = self.clist[::2] + assert_true(clist_view is not self.clist) + assert_true(clist_view._data is self.clist._data) + assert_array_equal(clist_view._offsets, self.clist._offsets[::2]) + assert_array_equal(clist_view._lengths, self.clist._lengths[::2]) + for i, e in enumerate(clist_view): + assert_array_equal(e, self.clist[i*2]) + + +class TestTractogramItem(unittest.TestCase): + + def test_creating_tractogram_item(self): + rng = np.random.RandomState(42) + points = rng.rand(rng.randint(10, 50), 3) + scalars = rng.rand(len(points), 5) + properties = rng.rand(42) + + # Create a streamline with only points + s = TractogramItem(points) + assert_equal(len(s), len(points)) + assert_array_equal(s.scalars, []) + assert_array_equal(s.properties, []) + + # Create a streamline with points, scalars and properties. + s = TractogramItem(points, scalars, properties) + assert_equal(len(s), len(points)) + assert_array_equal(s.points, points) + assert_array_equal(list(s), points) + assert_equal(len(s), len(scalars)) + assert_array_equal(s.scalars, scalars) + assert_array_equal(s.properties, properties) + + # # Create a streamline with different number of scalars. + # scalars = rng.rand(len(points)+3, 5) + # assert_raises(ValueError, TractogramItem, points, scalars) + + +class TestTractogram(unittest.TestCase): def setUp(self): self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") @@ -41,7 +257,7 @@ def setUp(self): def test_streamlines_creation_from_arrays(self): # Empty - streamlines = Streamlines() + streamlines = Tractogram() assert_equal(len(streamlines), 0) assert_arrays_equal(streamlines.points, []) assert_arrays_equal(streamlines.scalars, []) @@ -52,7 +268,7 @@ def test_streamlines_creation_from_arrays(self): pass # Only points - streamlines = Streamlines(points=self.points) + streamlines = Tractogram(points=self.points) assert_equal(len(streamlines), len(self.points)) assert_arrays_equal(streamlines.points, self.points) assert_arrays_equal(streamlines.scalars, []) @@ -63,7 +279,7 @@ def test_streamlines_creation_from_arrays(self): pass # Points, scalars and properties - streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) + streamlines = Tractogram(self.points, self.colors, self.mean_curvature_torsion) assert_equal(len(streamlines), len(self.points)) assert_arrays_equal(streamlines.points, self.points) assert_arrays_equal(streamlines.scalars, self.colors) @@ -73,11 +289,11 @@ def test_streamlines_creation_from_arrays(self): for streamline in streamlines: pass - #streamlines = Streamlines(self.points, scalars) + #streamlines = Tractogram(self.points, scalars) def test_streamlines_getter(self): - # Streamlines with only points - streamlines = Streamlines(points=self.points) + # Tractogram with only points + streamlines = Tractogram(points=self.points) selected_streamlines = streamlines[::2] assert_equal(len(selected_streamlines), (len(self.points)+1)//2) @@ -86,8 +302,8 @@ def test_streamlines_getter(self): assert_equal(sum(map(len, selected_streamlines.scalars)), 0) assert_equal(sum(map(len, selected_streamlines.properties)), 0) - # Streamlines with points, scalars and properties - streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) + # Tractogram with points, scalars and properties + streamlines = Tractogram(self.points, self.colors, self.mean_curvature_torsion) # Retrieve streamlines by their index for i, streamline in enumerate(streamlines): @@ -107,23 +323,12 @@ def test_streamlines_creation_from_coroutines(self): 
scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - # To create streamlines from coroutines use `LazyStreamlines`. - assert_raises(TypeError, Streamlines, points, scalars, properties) - - def test_to_world_space(self): - streamlines = Streamlines(self.points) - - # World space is (RAS+) with voxel size of 2x3x4mm. - streamlines.header.voxel_sizes = (2, 3, 4) - - new_streamlines = streamlines.to_world_space() - for new_pts, pts in zip(new_streamlines.points, self.points): - for dim, size in enumerate(streamlines.header.voxel_sizes): - assert_array_almost_equal(new_pts[:, dim], size*pts[:, dim]) + # To create streamlines from coroutines use `LazyTractogram`. + assert_raises(TypeError, Tractogram, points, scalars, properties) def test_header(self): - # Empty Streamlines, with default header - streamlines = Streamlines() + # Empty Tractogram, with default header + streamlines = Tractogram() assert_equal(streamlines.header.nb_streamlines, 0) assert_equal(streamlines.header.nb_scalars_per_point, 0) assert_equal(streamlines.header.nb_properties_per_streamline, 0) @@ -131,7 +336,7 @@ def test_header(self): assert_array_equal(streamlines.header.to_world_space, np.eye(4)) assert_equal(streamlines.header.extra, {}) - streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) + streamlines = Tractogram(self.points, self.colors, self.mean_curvature_torsion) assert_equal(streamlines.header.nb_streamlines, len(self.points)) assert_equal(streamlines.header.nb_scalars_per_point, self.colors[0].shape[1]) @@ -151,7 +356,7 @@ def test_header(self): repr(streamlines.header) -class TestLazyStreamlines(unittest.TestCase): +class TestLazyTractogram(unittest.TestCase): def setUp(self): self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") @@ -175,22 +380,22 @@ def setUp(self): self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) def test_lazy_streamlines_creation(self): - # To create streamlines from arrays use `Streamlines`. - assert_raises(TypeError, LazyStreamlines, self.points) + # To create streamlines from arrays use `Tractogram`. + assert_raises(TypeError, LazyTractogram, self.points) # Points, scalars and properties points = (x for x in self.points) scalars = (x for x in self.colors) properties = (x for x in self.mean_curvature_torsion) - # Creating LazyStreamlines from generators is not allowed as + # Creating LazyTractogram from generators is not allowed as # generators get exhausted and are not reusable unlike coroutines.
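+        # ("Coroutines" here means callables such as
+        # `lambda: (x for x in data)` that build a fresh generator on each
+        # call, so the data can be traversed more than once.)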
- assert_raises(TypeError, LazyStreamlines, points) - assert_raises(TypeError, LazyStreamlines, self.points, scalars) - assert_raises(TypeError, LazyStreamlines, properties_func=properties) + assert_raises(TypeError, LazyTractogram, points) + assert_raises(TypeError, LazyTractogram, self.points, scalars) + assert_raises(TypeError, LazyTractogram, properties_func=properties) - # Empty `LazyStreamlines` - streamlines = LazyStreamlines() + # Empty `LazyTractogram` + streamlines = LazyTractogram() with suppress_warnings(): assert_equal(len(streamlines), 0) assert_arrays_equal(streamlines.points, []) @@ -206,7 +411,7 @@ def test_lazy_streamlines_creation(self): scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - streamlines = LazyStreamlines(points, scalars, properties) + streamlines = LazyTractogram(points, scalars, properties) with suppress_warnings(): assert_equal(len(streamlines), self.nb_streamlines) @@ -218,10 +423,10 @@ def test_lazy_streamlines_creation(self): assert_arrays_equal(streamlines.scalars, self.colors) assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) - # Create `LazyStreamlines` from a coroutine yielding 3-tuples + # Create `LazyTractogram` from a coroutine yielding 3-tuples data = lambda: (x for x in zip(self.points, self.colors, self.mean_curvature_torsion)) - streamlines = LazyStreamlines.create_from_data(data) + streamlines = LazyTractogram.create_from_data(data) with suppress_warnings(): assert_equal(len(streamlines), self.nb_streamlines) assert_arrays_equal(streamlines.points, self.points) @@ -237,18 +442,18 @@ def test_lazy_streamlines_indexing(self): scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - # By default, `LazyStreamlines` object does not support indexing. - streamlines = LazyStreamlines(points, scalars, properties) + # By default, `LazyTractogram` object does not support indexing. + streamlines = LazyTractogram(points, scalars, properties) assert_raises(AttributeError, streamlines.__getitem__, 0) - # Create a `LazyStreamlines` object with indexing support. + # Create a `LazyTractogram` object with indexing support. def getitem_without_properties(idx): if isinstance(idx, int) or isinstance(idx, np.integer): return self.points[idx], self.colors[idx] return list(zip(self.points[idx], self.colors[idx])) - streamlines = LazyStreamlines(points, scalars, properties, getitem_without_properties) + streamlines = LazyTractogram(points, scalars, properties, getitem_without_properties) points, scalars = streamlines[0] assert_array_equal(points, self.points[0]) assert_array_equal(scalars, self.colors[0]) @@ -270,12 +475,12 @@ def test_lazy_streamlines_len(self): warnings.simplefilter("always") # Always trigger warnings. # Calling `len` will create new generators each time. - streamlines = LazyStreamlines(points, scalars, properties) + streamlines = LazyTractogram(points, scalars, properties) # This should produce a warning message. assert_equal(len(streamlines), self.nb_streamlines) assert_equal(len(w), 1) - streamlines = LazyStreamlines(points, scalars, properties) + streamlines = LazyTractogram(points, scalars, properties) # This should still produce a warning message. assert_equal(len(streamlines), self.nb_streamlines) assert_equal(len(w), 2) @@ -287,7 +492,7 @@ def test_lazy_streamlines_len(self): with clear_and_catch_warnings(record=True, modules=[base_format]) as w: # Once we iterated through the streamlines, we know the length. 
- streamlines = LazyStreamlines(points, scalars, properties) + streamlines = LazyTractogram(points, scalars, properties) assert_true(streamlines.header.nb_streamlines is None) for streamline in streamlines: pass @@ -299,15 +504,15 @@ def test_lazy_streamlines_len(self): with clear_and_catch_warnings(record=True, modules=[base_format]) as w: # It first checks if number of streamlines is in the header. - streamlines = LazyStreamlines(points, scalars, properties) + streamlines = LazyTractogram(points, scalars, properties) streamlines.header.nb_streamlines = 1234 # This should *not* produce a warning. assert_equal(len(streamlines), 1234) assert_equal(len(w), 0) def test_lazy_streamlines_header(self): - # Empty `LazyStreamlines`, with default header - streamlines = LazyStreamlines() + # Empty `LazyTractogram`, with default header + streamlines = LazyTractogram() assert_true(streamlines.header.nb_streamlines is None) assert_equal(streamlines.header.nb_scalars_per_point, 0) assert_equal(streamlines.header.nb_properties_per_streamline, 0) @@ -318,7 +523,7 @@ def test_lazy_streamlines_header(self): points = lambda: (x for x in self.points) scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - streamlines = LazyStreamlines(points) + streamlines = LazyTractogram(points) header = streamlines.header assert_equal(header.nb_scalars_per_point, 0) diff --git a/nibabel/streamlines/tests/test_header.py b/nibabel/streamlines/tests/test_header.py index a36a257818..398195f615 100644 --- a/nibabel/streamlines/tests/test_header.py +++ b/nibabel/streamlines/tests/test_header.py @@ -3,11 +3,11 @@ from nose.tools import assert_equal, assert_true from numpy.testing import assert_array_equal -from nibabel.streamlines.header import StreamlinesHeader +from nibabel.streamlines.header import TractogramHeader def test_streamlines_header(): - header = StreamlinesHeader() + header = TractogramHeader() assert_true(header.nb_streamlines is None) assert_true(header.nb_scalars_per_point is None) assert_true(header.nb_properties_per_streamline is None) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index e6ab8cab17..1e94b79911 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -12,7 +12,7 @@ from nibabel.testing import assert_arrays_equal from nose.tools import assert_equal, assert_raises, assert_true, assert_false -from ..base_format import Streamlines, LazyStreamlines +from ..base_format import Tractogram, LazyTractogram from ..base_format import HeaderError, UsageWarning from ..header import Field from .. 
import trk @@ -147,7 +147,7 @@ def test_load_empty_file(self): streamlines = nib.streamlines.load(empty_filename, ref=self.to_world_space, lazy_load=False) - assert_true(type(streamlines), Streamlines) + assert_true(type(streamlines), Tractogram) check_streamlines(streamlines, 0, [], [], []) def test_load_simple_file(self): @@ -155,7 +155,7 @@ def test_load_simple_file(self): streamlines = nib.streamlines.load(simple_filename, ref=self.to_world_space, lazy_load=False) - assert_true(type(streamlines), Streamlines) + assert_true(type(streamlines), Tractogram) check_streamlines(streamlines, self.nb_streamlines, self.points, [], []) @@ -163,7 +163,7 @@ def test_load_simple_file(self): streamlines = nib.streamlines.load(simple_filename, ref=self.to_world_space, lazy_load=True) - assert_true(type(streamlines), LazyStreamlines) + assert_true(type(streamlines), LazyTractogram) check_streamlines(streamlines, self.nb_streamlines, self.points, [], []) @@ -182,7 +182,7 @@ def test_load_complex_file(self): streamlines = nib.streamlines.load(complex_filename, ref=self.to_world_space, lazy_load=False) - assert_true(type(streamlines), Streamlines) + assert_true(type(streamlines), Tractogram) check_streamlines(streamlines, self.nb_streamlines, self.points, scalars, properties) @@ -190,12 +190,12 @@ def test_load_complex_file(self): streamlines = nib.streamlines.load(complex_filename, ref=self.to_world_space, lazy_load=True) - assert_true(type(streamlines), LazyStreamlines) + assert_true(type(streamlines), LazyTractogram) check_streamlines(streamlines, self.nb_streamlines, self.points, scalars, properties) def test_save_simple_file(self): - streamlines = Streamlines(self.points) + streamlines = Tractogram(self.points) for ext in nib.streamlines.FORMATS.keys(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: nib.streamlines.save(streamlines, f.name) @@ -204,7 +204,7 @@ def test_save_simple_file(self): self.points, [], []) def test_save_complex_file(self): - streamlines = Streamlines(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) + streamlines = Tractogram(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) for ext, cls in nib.streamlines.FORMATS.items(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: with clear_and_catch_warnings(record=True, modules=[trk]) as w: diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index e29b288dc5..6013fe5507 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -9,7 +9,7 @@ from nose.tools import assert_equal, assert_raises, assert_true from .. import base_format -from ..base_format import Streamlines, LazyStreamlines +from ..base_format import Tractogram, LazyTractogram from ..base_format import DataError, HeaderError, HeaderWarning, UsageWarning from .. 
import trk @@ -128,7 +128,7 @@ def test_load_file_with_wrong_information(self): assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) def test_write_simple_file(self): - streamlines = Streamlines(self.points) + streamlines = Tractogram(self.points) trk_file = BytesIO() TrkFile.save(streamlines, trk_file) @@ -146,7 +146,7 @@ def test_write_simple_file(self): def test_write_complex_file(self): # With scalars - streamlines = Streamlines(self.points, scalars=self.colors) + streamlines = Tractogram(self.points, scalars=self.colors) trk_file = BytesIO() TrkFile.save(streamlines, trk_file) @@ -158,7 +158,7 @@ def test_write_complex_file(self): self.points, self.colors, []) # With properties - streamlines = Streamlines(self.points, properties=self.mean_curvature_torsion) + streamlines = Tractogram(self.points, properties=self.mean_curvature_torsion) trk_file = BytesIO() TrkFile.save(streamlines, trk_file) @@ -169,7 +169,7 @@ def test_write_complex_file(self): self.points, [], self.mean_curvature_torsion) # With scalars and properties - streamlines = Streamlines(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) + streamlines = Tractogram(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) trk_file = BytesIO() TrkFile.save(streamlines, trk_file) @@ -191,14 +191,14 @@ def test_write_erroneous_file(self): [(0, 1, 0)], [(0, 0, 1)]] - streamlines = Streamlines(self.points, scalars) + streamlines = Tractogram(self.points, scalars) assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) # No scalars for every streamlines scalars = [[(1, 0, 0)]*1, [(0, 1, 0)]*2] - streamlines = Streamlines(self.points, scalars) + streamlines = Tractogram(self.points, scalars) assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) # # Unit test moved to test_base_format.py @@ -207,7 +207,7 @@ def test_write_erroneous_file(self): # [(0, 1, 0), (0, 1)], # [(0, 0, 1)]*5] - #streamlines = Streamlines(self.points, scalars) + #streamlines = Tractogram(self.points, scalars) #assert_raises(ValueError, TrkFile.save, streamlines, BytesIO()) # # Unit test moved to test_base_format.py @@ -216,7 +216,7 @@ def test_write_erroneous_file(self): # [(0, 1)]*2, # [(0, 0, 1)]*5] - # streamlines = Streamlines(self.points, scalars) + # streamlines = Tractogram(self.points, scalars) # assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) # # Unit test moved to test_base_format.py @@ -224,14 +224,14 @@ def test_write_erroneous_file(self): # properties = [np.array([1.11, 1.22], dtype="f4"), # np.array([2.11], dtype="f4"), # np.array([3.11, 3.22], dtype="f4")] - # streamlines = Streamlines(self.points, properties=properties) + # streamlines = Tractogram(self.points, properties=properties) # assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) # # Unit test moved to test_base_format.py # No properties for every streamlines properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11, 2.22], dtype="f4")] - streamlines = Streamlines(self.points, properties=properties) + streamlines = Tractogram(self.points, properties=properties) assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) def test_write_file_lazy_streamlines(self): @@ -239,7 +239,7 @@ def test_write_file_lazy_streamlines(self): scalars = lambda: (scalar for scalar in self.colors) properties = lambda: (prop for prop in self.mean_curvature_torsion) - streamlines = LazyStreamlines(points, scalars, properties) + streamlines = LazyTractogram(points, scalars, properties) # No need to manually set 
`nb_streamlines` in the header since we count
        # them as we write.
        #streamlines.header.nb_streamlines = self.nb_streamlines
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index 70df0385d0..53087ddad5 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -13,9 +13,9 @@
 from nibabel.volumeutils import (native_code, swapped_code)
 
 from nibabel.streamlines.base_format import CompactList
-from nibabel.streamlines.base_format import StreamlinesFile
+from nibabel.streamlines.base_format import TractogramFile
 from nibabel.streamlines.base_format import DataError, HeaderError, HeaderWarning
-from nibabel.streamlines.base_format import Streamlines, LazyStreamlines
+from nibabel.streamlines.base_format import Tractogram, LazyTractogram
 
 from nibabel.streamlines.header import Field
 from nibabel.streamlines.utils import get_affine_from_reference
@@ -140,15 +140,15 @@ def __init__(self, fileobj):
             # Keep the file position where the data begin.
             self.offset_data = f.tell()
 
-        if f.name is not None and self.header[Field.NB_STREAMLINES] > 0:
-            filesize = os.path.getsize(f.name) - self.offset_data
-            # Remove properties
-            filesize -= self.header[Field.NB_STREAMLINES] * self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * 4.
-            # Remove the points count at the beginning of each streamline.
-            filesize -= self.header[Field.NB_STREAMLINES] * 4.
-            # Get nb points.
-            nb_points = filesize / ((3 + self.header[Field.NB_SCALARS_PER_POINT]) * 4.)
-            self.header[Field.NB_POINTS] = int(nb_points)
+        # if f.name is not None and self.header[Field.NB_STREAMLINES] > 0:
+        #     filesize = os.path.getsize(f.name) - self.offset_data
+        #     # Remove properties
+        #     filesize -= self.header[Field.NB_STREAMLINES] * self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * 4.
+        #     # Remove the points count at the beginning of each streamline.
+        #     filesize -= self.header[Field.NB_STREAMLINES] * 4.
+        #     # Get nb points.
+        #     nb_points = filesize / ((3 + self.header[Field.NB_SCALARS_PER_POINT]) * 4.)
+        #     self.header[Field.NB_POINTS] = int(nb_points)
 
     def __iter__(self):
         i4_dtype = np.dtype(self.endianness + "i4")
@@ -162,10 +162,10 @@ def __iter__(self):
             nb_pts_and_scalars = int(3 + self.header[Field.NB_SCALARS_PER_POINT])
             pts_and_scalars_size = int(nb_pts_and_scalars * f4_dtype.itemsize)
 
-            slice_pts_and_scalars = lambda data: (data, [])
-            if self.header[Field.NB_SCALARS_PER_POINT] > 0:
-                # This is faster than `np.split`
-                slice_pts_and_scalars = lambda data: (data[:, :3], data[:, 3:])
+            #slice_pts_and_scalars = lambda data: (data, [[]])
+            #if self.header[Field.NB_SCALARS_PER_POINT] > 0:
+            # This is faster than `np.split`
+            slice_pts_and_scalars = lambda data: (data[:, :3], data[:, 3:])
 
             # Using np.fromfile would be faster, but does not support StringIO
             read_pts_and_scalars = lambda nb_pts: slice_pts_and_scalars(np.ndarray(shape=(nb_pts, nb_pts_and_scalars),
                                                                                    dtype=f4_dtype,
                                                                                    buffer=f.read(nb_pts * pts_and_scalars_size)))
 
             properties_size = int(self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize)
             read_properties = lambda: []
             if self.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0:
                 read_properties = lambda: np.fromstring(f.read(properties_size),
                                                         dtype=f4_dtype,
                                                         count=self.header[Field.NB_PROPERTIES_PER_STREAMLINE])
@@ -202,14 +202,6 @@ def __iter__(self):
                 pts, scalars = read_pts_and_scalars(nb_pts)
                 properties = read_properties()
 
-                # TRK's streamlines are in 'voxelmm' space, we send them to voxel space.
-                pts = pts / self.header[Field.VOXEL_SIZES]
-                # TrackVis considers coordinate (0,0,0) to be the corner of the
-                # voxel whereas streamlines returned assume (0,0,0) to be the
-                # center of the voxel. Thus, streamlines are shifted of half
-                #a voxel.
-                pts -= np.array(self.header[Field.VOXEL_SIZES])/2.
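The per-point arithmetic removed above does not disappear: `load` below folds it into the affine instead, so the whole transformation can be applied in one pass. For a diagonal `to_world_space` the two formulations agree, as this standalone check shows (illustrative only, not part of the patch):

    import numpy as np

    voxel_sizes = np.array([2., 3., 4.], dtype=np.float32)

    affine = np.eye(4)                         # stand-in for the header affine
    affine[range(3), range(3)] /= voxel_sizes  # voxelmm -> voxel scaling
    affine[:-1, -1] -= voxel_sizes / 2.        # the half-voxel shift

    pts = np.array([[2., 3., 4.], [4., 6., 8.]], dtype=np.float32)  # voxelmm
    by_hand = pts / voxel_sizes - voxel_sizes / 2.   # the removed code path
    by_affine = np.dot(pts, affine[:3, :3].T) + affine[:3, -1]
    assert np.allclose(by_hand, by_affine)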
-                yield pts, scalars, properties
 
                 i += 1
@@ -317,7 +309,7 @@ def write(self, streamlines):
         self.file.write(self.header[0].tostring())
 
 
-class TrkFile(StreamlinesFile):
+class TrkFile(TractogramFile):
     ''' Convenience class to encapsulate TRK file format.
 
     Note
@@ -393,80 +385,96 @@ def load(fileobj, ref=None, lazy_load=False):
 
         Returns
         -------
-        streamlines : Streamlines object
+        streamlines : Tractogram object
            Returns an object containing streamlines' data and header
-           information. See `nibabel.Streamlines`.
+           information. See `nibabel.Tractogram`.
 
         Notes
         -----
-        Streamlines are assumed to be in voxel space where coordinate (0,0,0)
+        Tractograms are assumed to be in voxel space where coordinate (0,0,0)
         refers to the center of the voxel.
         '''
        trk_reader = TrkReader(fileobj)
 
-        # Check if reference space matches one from TRK's header.
+        # TRK's streamlines are in 'voxelmm' space, we send them to rasmm.
         affine = trk_reader.header[Field.to_world_space]
-        if ref is not None:
-            affine = get_affine_from_reference(ref)
-            if not np.allclose(affine, trk_reader.header[Field.to_world_space]):
-                raise ValueError("Reference space provided does not match the "
-                                 " one from the TRK file header. Use `ref=None`"
-                                 " to use one contained in the TRK file")
+        affine[range(3), range(3)] /= trk_reader.header[Field.VOXEL_SIZES]
 
-        #points = lambda: (x[0] for x in trk_reader)
-        #scalars = lambda: (x[1] for x in trk_reader)
-        #properties = lambda: (x[2] for x in trk_reader)
-        data = lambda: iter(trk_reader)
+        # TrackVis considers coordinate (0,0,0) to be the corner of the
+        # voxel whereas streamlines returned assume (0,0,0) to be the
+        # center of the voxel. Thus, streamlines are shifted by half
+        # a voxel.
+        affine[:-1, -1] -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2.
 
         if lazy_load:
-            streamlines = LazyStreamlines.create_from_data(data)
+            def _apply_transform(trk_reader):
+                for pts, scals, props in trk_reader:
+                    # TRK's streamlines are in 'voxelmm' space, we send them
+                    # to voxel space.
+                    pts = pts / trk_reader.header[Field.VOXEL_SIZES]
+                    # TrackVis considers coordinate (0,0,0) to be the corner
+                    # of the voxel whereas streamlines returned assume (0,0,0)
+                    # to be the center of the voxel. Thus, streamlines are
+                    # shifted by half a voxel.
+                    pts -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2.
+                    yield pts, scals, props
+
+            data = lambda: _apply_transform(trk_reader)
+            streamlines = LazyTractogram.create_from_data(data)
 
             # Overwrite scalars and properties if there are none
             if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0:
                 streamlines.scalars = lambda: []
             if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0:
                 streamlines.properties = lambda: []
 
-        elif Field.NB_POINTS in trk_reader.header:
-            # 'count' field is provided, we can avoid creating list of numpy
-            # arrays (more memory efficient).
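The idea behind this 'count' branch: when the header announces the total number of points, all streamlines can share one flat, preallocated array plus per-streamline offsets and lengths, instead of one small ndarray per streamline. A toy version of that layout (sizes are made up; this is not nibabel's `CompactList`):

    import numpy as np

    nb_points = 8                                    # would come from the header
    data = np.empty((nb_points, 3), dtype=np.float32)
    offsets, lengths = [], []

    offset = 0
    for pts in [np.ones((2, 3)), np.zeros((5, 3)), np.ones((1, 3))]:
        offsets.append(offset)
        lengths.append(len(pts))
        data[offset:offset + len(pts)] = pts
        offset += len(pts)

    # Each streamline is recovered as a view into the flat block, no copy.
    third = data[offsets[2]:offsets[2] + lengths[2]]
    assert third.shape == (1, 3)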
- - nb_streamlines = trk_reader.header[Field.NB_STREAMLINES] - nb_points = trk_reader.header[Field.NB_POINTS] - - points = CompactList(preallocate=(nb_points, 3)) - scalars = CompactList(preallocate=(nb_points, trk_reader.header[Field.NB_SCALARS_PER_POINT])) - properties = np.empty((nb_streamlines, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]), - dtype=np.float32) - - offset = 0 - offsets = [] - lengths = [] - for i, (pts, scals, props) in enumerate(data()): - offsets.append(offset) - lengths.append(len(pts)) - try: - points.data[offset:offset+len(pts)] = pts - except: - from ipdb import set_trace as dbg - dbg() - scalars.data[offset:offset+len(scals)] = scals - properties[i] = props - offset += len(pts) - - points.offsets = offsets - scalars.offsets = offsets - points.lengths = lengths - scalars.lengths = lengths - streamlines = Streamlines(points, scalars, properties) - # Overwrite scalars and properties if there is none - if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: - streamlines.scalars = [] - if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: - streamlines.properties = [] + # elif Field.NB_POINTS in trk_reader.header: + # # 'count' field is provided, we can avoid creating list of numpy + # # arrays (more memory efficient). + + # nb_streamlines = trk_reader.header[Field.NB_STREAMLINES] + # nb_points = trk_reader.header[Field.NB_POINTS] + + # points = CompactList() + # points._data = np.empty((nb_points, 3), dtype=np.float32) + + # scalars = CompactList() + # scalars._data = np.empty((nb_points, trk_reader.header[Field.NB_SCALARS_PER_POINT]), + # dtype=np.float32) + + # properties = CompactList() + # properties._data = np.empty((nb_streamlines, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]), + # dtype=np.float32) + + # offset = 0 + # offsets = [] + # lengths = [] + # for i, (pts, scals, props) in enumerate(trk_reader): + # offsets.append(offset) + # lengths.append(len(pts)) + # points._data[offset:offset+len(pts)] = pts + # scalars._data[offset:offset+len(scals)] = scals + # properties._data[i] = props + # offset += len(pts) + + # points.offsets = offsets + # scalars.offsets = offsets + # points.lengths = lengths + # scalars.lengths = lengths + + # streamlines = Tractogram(points, scalars, properties) + # streamlines.apply_affine(affine) + + # # Overwrite scalars and properties if there is none + # if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: + # streamlines.scalars = [] + # if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: + # streamlines.properties = [] else: - streamlines = Streamlines(*zip(*data())) + streamlines = Tractogram.create_from_generator(trk_reader) + #streamlines = Tractogram(*zip(*trk_reader)) + streamlines.apply_affine(affine) # Overwrite scalars and properties if there is none if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: @@ -502,9 +510,9 @@ def save(streamlines, fileobj, ref=None): Parameters ---------- - streamlines : Streamlines object + streamlines : Tractogram object Object containing streamlines' data and header information. - See 'nibabel.Streamlines'. + See 'nibabel.Tractogram'. fileobj : string or file-like object If string, a filename; otherwise an open file-like object @@ -516,7 +524,7 @@ def save(streamlines, fileobj, ref=None): Notes ----- - Streamlines are assumed to be in voxel space where coordinate (0,0,0) + Tractogram are assumed to be in voxel space where coordinate (0,0,0) refers to the center of the voxel. 
''' if ref is not None: From a37ff5442c79e32c0e6f3b86bb2789654e867609 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 30 Oct 2015 23:21:38 -0400 Subject: [PATCH 13/54] Moves some unit tests --- nibabel/streamlines/base_format.py | 11 ++++++++++ nibabel/streamlines/tests/test_base_format.py | 16 ++++++++++++++- nibabel/streamlines/tests/test_trk.py | 20 +++++++++---------- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 95fc2669b5..51ccf31f68 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -276,6 +276,9 @@ def create_from_generator(cls, gen): scals = np.asarray(first_element[1]) props = np.asarray(first_element[2]) + scals_shape = scals.shape + props_shape = props.shape + points._data = np.empty((BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype) scalars._data = np.empty((BUFFER_SIZE, scals.shape[1]), dtype=scals.dtype) properties = np.empty((BUFFER_SIZE, props.shape[0]), dtype=props.dtype) @@ -286,6 +289,14 @@ def create_from_generator(cls, gen): scals = np.asarray(scals) props = np.asarray(props) + if scals.shape[1] != scals_shape[1]: + raise ValueError("Number of scalars differs from one" + " point or streamline to another") + + if props.shape != props_shape: + raise ValueError("Number of properties differs from one" + " streamline to another") + end = offset + len(pts) if end >= len(points._data): # Resize is needed (at least `len(pts)` items will be added). diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index b699066c1c..23dbd425b3 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -289,7 +289,20 @@ def test_streamlines_creation_from_arrays(self): for streamline in streamlines: pass - #streamlines = Tractogram(self.points, scalars) + # Inconsistent number of scalars between points + scalars = [[(1, 0, 0)]*1, + [(0, 1, 0), (0, 1)], + [(0, 0, 1)]*5] + + assert_raises(ValueError, Tractogram, self.points, scalars) + + # Unit test moved to test_base_format.py + # Inconsistent number of scalars between streamlines + scalars = [[(1, 0, 0)]*1, + [(0, 1)]*2, + [(0, 0, 1)]*5] + + assert_raises(ValueError, Tractogram, self.points, scalars) def test_streamlines_getter(self): # Tractogram with only points @@ -356,6 +369,7 @@ def test_header(self): repr(streamlines.header) + class TestLazyTractogram(unittest.TestCase): def setUp(self): diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 6013fe5507..0c5d58aadf 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -207,8 +207,8 @@ def test_write_erroneous_file(self): # [(0, 1, 0), (0, 1)], # [(0, 0, 1)]*5] - #streamlines = Tractogram(self.points, scalars) - #assert_raises(ValueError, TrkFile.save, streamlines, BytesIO()) + # streamlines = Tractogram(self.points, scalars) + # assert_raises(ValueError, TrkFile.save, streamlines, BytesIO()) # # Unit test moved to test_base_format.py # # Inconsistent number of scalars between streamlines @@ -219,15 +219,15 @@ def test_write_erroneous_file(self): # streamlines = Tractogram(self.points, scalars) # assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) - # # Unit test moved to test_base_format.py - # # Inconsistent number of properties - # properties = [np.array([1.11, 1.22], dtype="f4"), - # np.array([2.11], dtype="f4"), - # 
np.array([3.11, 3.22], dtype="f4")] - # streamlines = Tractogram(self.points, properties=properties) - # assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + # Unit test moved to test_base_format.py + # Inconsistent number of properties + properties = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + streamlines = Tractogram(self.points, properties=properties) + assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) - # # Unit test moved to test_base_format.py + # Unit test moved to test_base_format.py # No properties for every streamlines properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11, 2.22], dtype="f4")] From 21f783a514c4c78911c72f5150db8e9433bb7ec5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 2 Nov 2015 08:30:39 -0500 Subject: [PATCH 14/54] Reduced memory usage. --- nibabel/streamlines/base_format.py | 22 ++++++++---- nibabel/streamlines/tests/test_base_format.py | 18 +++++++++- nibabel/streamlines/trk.py | 36 ++++++++----------- 3 files changed, 47 insertions(+), 29 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 51ccf31f68..060395f49b 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -258,7 +258,7 @@ def __init__(self, points=None, scalars=None, properties=None): @classmethod def create_from_generator(cls, gen): - BUFFER_SIZE = 1000 + BUFFER_SIZE = 1000000 points = CompactList() scalars = CompactList() @@ -301,7 +301,7 @@ def create_from_generator(cls, gen): if end >= len(points._data): # Resize is needed (at least `len(pts)` items will be added). points._data.resize((len(points._data) + len(pts)+BUFFER_SIZE, pts.shape[1])) - scalars._data.resize((len(scalars._data) + len(scalars)+BUFFER_SIZE, scals.shape[1])) + scalars._data.resize((len(scalars._data) + len(scals)+BUFFER_SIZE, scals.shape[1])) points._offsets.append(offset) points._lengths.append(len(pts)) @@ -315,14 +315,24 @@ def create_from_generator(cls, gen): properties[i] = props + # Clear unused memory. + points._data.resize((offset, pts.shape[1])) + + if scals_shape[1] == 0: + # Because resizing an empty ndarray creates memory! + scalars._data = np.empty((offset, scals.shape[1])) + else: + scalars._data.resize((offset, scals.shape[1])) + # Share offsets and lengths between points and scalars. scalars._offsets = points._offsets scalars._lengths = points._lengths - # Clear unused memory. - points._data.resize((offset, pts.shape[1])) - scalars._data.resize((offset, scals.shape[1])) - properties.resize((i+1, props.shape[0])) + if props_shape[0] == 0: + # Because resizing an empty ndarray creates memory! 
+ properties = np.empty((i+1, props.shape[0])) + else: + properties.resize((i+1, props.shape[0])) return cls(points, scalars, properties) diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index 23dbd425b3..1aa650489e 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -330,13 +330,29 @@ def test_streamlines_getter(self): assert_arrays_equal(r_streamlines.scalars, self.colors[::-1]) assert_arrays_equal(r_streamlines.properties, self.mean_curvature_torsion[::-1]) + def test_streamlines_creation_from_generator(self): + # Create `Tractogram` from a generator yielding 3-tuples + gen = (x for x in zip(self.points, self.colors, self.mean_curvature_torsion)) + + streamlines = Tractogram.create_from_generator(gen) + with suppress_warnings(): + assert_equal(len(streamlines), self.nb_streamlines) + + assert_arrays_equal(streamlines.points, self.points) + assert_arrays_equal(streamlines.scalars, self.colors) + assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + + # Check if we can iterate through the streamlines. + for streamline in streamlines: + pass + def test_streamlines_creation_from_coroutines(self): # Points, scalars and properties points = lambda: (x for x in self.points) scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - # To create streamlines from coroutines use `LazyTractogram`. + # To create streamlines from multiple coroutines use `LazyTractogram`. assert_raises(TypeError, Tractogram, points, scalars, properties) def test_header(self): diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 53087ddad5..3194218190 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -154,30 +154,12 @@ def __iter__(self): i4_dtype = np.dtype(self.endianness + "i4") f4_dtype = np.dtype(self.endianness + "f4") - #from io import BufferedReader with Opener(self.fileobj) as f: - #f = BufferedReader(f.fobj) start_position = f.tell() nb_pts_and_scalars = int(3 + self.header[Field.NB_SCALARS_PER_POINT]) pts_and_scalars_size = int(nb_pts_and_scalars * f4_dtype.itemsize) - - #slice_pts_and_scalars = lambda data: (data, [[]]) - #if self.header[Field.NB_SCALARS_PER_POINT] > 0: - # This is faster than `np.split` - slice_pts_and_scalars = lambda data: (data[:, :3], data[:, 3:]) - - # Using np.fromfile would be faster, but does not support StringIO - read_pts_and_scalars = lambda nb_pts: slice_pts_and_scalars(np.ndarray(shape=(nb_pts, nb_pts_and_scalars), - dtype=f4_dtype, - buffer=f.read(nb_pts * pts_and_scalars_size))) - properties_size = int(self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize) - read_properties = lambda: [] - if self.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: - read_properties = lambda: np.fromstring(f.read(properties_size), - dtype=f4_dtype, - count=self.header[Field.NB_PROPERTIES_PER_STREAMLINE]) # Set the file position at the beginning of the data. f.seek(self.offset_data, os.SEEK_SET) @@ -188,6 +170,7 @@ def __iter__(self): nb_streamlines = np.inf i = 0 + nb_pts_dtype = i4_dtype.str[:-1] while i < nb_streamlines: nb_pts_str = f.read(i4_dtype.itemsize) @@ -196,13 +179,22 @@ def __iter__(self): break # Read number of points of the next streamline. 
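At this step the reader is positioned on one TRK record: an int32 point count, then `count` rows of float32 points (plus any per-point scalars), then the per-streamline properties. A miniature of that parse on fabricated little-endian bytes (illustrative only, not the module's API):

    import struct
    import numpy as np
    from io import BytesIO

    nb_scalars, nb_properties = 1, 2
    pts_block = np.arange(2 * (3 + nb_scalars), dtype="<f4")
    props_block = np.array([1.5, 2.5], dtype="<f4")
    f = BytesIO(struct.pack("<i", 2) + pts_block.tobytes() + props_block.tobytes())

    nb_pts = struct.unpack("<i", f.read(4))[0]
    rec = np.ndarray(shape=(nb_pts, 3 + nb_scalars), dtype="<f4",
                     buffer=f.read(nb_pts * (3 + nb_scalars) * 4))
    points, scalars = rec[:, :3], rec[:, 3:]   # same split as the code below
    properties = np.ndarray(shape=(nb_properties,), dtype="<f4",
                            buffer=f.read(nb_properties * 4))
    assert points.shape == (2, 3) and properties[1] == 2.5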
- nb_pts = struct.unpack(i4_dtype.str[:-1], nb_pts_str)[0] + nb_pts = struct.unpack(nb_pts_dtype, nb_pts_str)[0] # Read streamline's data - pts, scalars = read_pts_and_scalars(nb_pts) - properties = read_properties() + points_and_scalars = np.ndarray(shape=(nb_pts, nb_pts_and_scalars), + dtype=f4_dtype, + buffer=f.read(nb_pts * pts_and_scalars_size)) + + points = points_and_scalars[:, :3] + scalars = points_and_scalars[:, 3:] + + # Read properties + properties = np.ndarray(shape=(self.header[Field.NB_PROPERTIES_PER_STREAMLINE],), + dtype=f4_dtype, + buffer=f.read(properties_size)) - yield pts, scalars, properties + yield points, scalars, properties i += 1 # In case the 'count' field was not provided. From a17a8cfcb72b9dae30d9b6f034efd44bdbca2db1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 2 Nov 2015 11:04:05 -0500 Subject: [PATCH 15/54] Refactored streamlines->tractogram and points->streamlines --- nibabel/streamlines/base_format.py | 220 +++++------ nibabel/streamlines/tests/test_base_format.py | 353 +++++++++--------- nibabel/streamlines/tests/test_streamlines.py | 134 +++---- nibabel/streamlines/tests/test_trk.py | 152 ++++---- nibabel/streamlines/trk.py | 62 +-- nibabel/testing/__init__.py | 14 +- 6 files changed, 462 insertions(+), 473 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 060395f49b..42e34d0a37 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -2,6 +2,8 @@ import numpy as np from warnings import warn +from abc import ABCMeta, abstractmethod, abstractproperty + from nibabel.externals.six.moves import zip_longest from nibabel.affines import apply_affine @@ -44,7 +46,7 @@ def __init__(self, iterable=None): if iterable is not None: # Initialize the `CompactList` object from iterable's item. - BUFFER_SIZE = 1000 + BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. offset = 0 for i, e in enumerate(iterable): @@ -190,15 +192,16 @@ def __len__(self): def __repr__(self): return repr(list(self)) + class TractogramItem(object): ''' Class containing information about one streamline. - ``TractogramItem`` objects have three main properties: `points`, `scalars` + ``TractogramItem`` objects have three main properties: `streamline`, `scalars` and ``properties``. Parameters ---------- - points : ndarray of shape (N, 3) + streamline : ndarray of shape (N, 3) Points of this streamline represented as an ndarray of shape (N, 3) where N is the number of points. @@ -211,56 +214,56 @@ class TractogramItem(object): Properties associated with this streamline and represented as an ndarray of shape (P,) where P is the number of properties. 
    '''
-    def __init__(self, points, scalars=None, properties=None):
-        #if scalars is not None and len(points) != len(scalars):
-        #    raise ValueError("First dimension of points and scalars must match.")
+    def __init__(self, streamline, scalars=None, properties=None):
+        #if scalars is not None and len(streamline) != len(scalars):
+        #    raise ValueError("First dimension of streamline and scalars must match.")
 
-        self.points = np.asarray(points)
+        self.streamline = np.asarray(streamline)
         self.scalars = np.asarray([] if scalars is None else scalars)
         self.properties = np.asarray([] if properties is None else properties)
 
     def __iter__(self):
-        return iter(self.points)
+        return iter(self.streamline)
 
     def __len__(self):
-        return len(self.points)
+        return len(self.streamline)
 
 
 class Tractogram(object):
     ''' Class containing information about streamlines.
 
-    Tractogram objects have three main properties: ``points``, ``scalars``
+    Tractogram objects have three main properties: ``streamlines``, ``scalars``
     and ``properties``. Tractogram objects can be iterated over, producing a
-    tuple of ``points``, ``scalars`` and ``properties`` for each streamline.
+    tuple of ``streamlines``, ``scalars`` and ``properties`` for each streamline.
 
     Parameters
     ----------
-    points : list of ndarray of shape (N, 3)
-        Sequence of T streamlines. One streamline is an ndarray of shape (N, 3)
-        where N is the number of points in a streamline.
-
-    scalars : list of ndarray of shape (N, M)
-        Sequence of T ndarrays of shape (N, M) where T is the number of
-        streamlines defined by ``points``, N is the number of points
-        for a particular streamline and M is the number of scalars
+    streamlines : list of ndarray of shape (Nt, 3)
+        Sequence of T streamlines. One streamline is an ndarray of shape (Nt, 3)
+        where Nt is the number of points of streamline t.
+
+    scalars : list of ndarray of shape (Nt, M)
+        Sequence of T ndarrays of shape (Nt, M) where T is the number of
+        streamlines defined by ``streamlines``, Nt is the number of points
+        for a particular streamline t and M is the number of scalars
         associated to each point (excluding the three coordinates).
 
     properties : list of ndarray of shape (P,)
         Sequence of T ndarrays of shape (P,) where T is the number of
-        streamlines defined by ``points``, P is the number of properties
+        streamlines defined by ``streamlines``, P is the number of properties
         associated to each streamline.
     '''
-    def __init__(self, points=None, scalars=None, properties=None):
+    def __init__(self, streamlines=None, scalars=None, properties=None):
         self._header = TractogramHeader()
-        self.points = points
+        self.streamlines = streamlines
         self.scalars = scalars
         self.properties = properties
 
     @classmethod
     def create_from_generator(cls, gen):
-        BUFFER_SIZE = 1000000
+        BUFFER_SIZE = 10000000  # About 128 Mb if item shape is 3.
 
-        points = CompactList()
+        streamlines = CompactList()
         scalars = CompactList()
         properties = np.array([])
 
@@ -269,7 +272,7 @@ def create_from_generator(cls, gen):
             first_element = next(gen)
             gen = itertools.chain([first_element], gen)
         except StopIteration:
-            return cls(points, scalars, properties)
+            return cls(streamlines, scalars, properties)
 
         # Allocate some buffer memory.
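The buffering strategy that follows, over-allocate in large steps, write into the buffer, then trim once the generator is exhausted, amortizes the cost of growing the array. A self-contained sketch with a deliberately tiny buffer (the real code uses millions of rows):

    import numpy as np

    BUFFER_SIZE = 4

    def collect(gen):
        data = np.empty((BUFFER_SIZE, 3), dtype=np.float32)
        offset = 0
        for pts in gen:
            pts = np.asarray(pts, dtype=np.float32)
            end = offset + len(pts)
            if end > len(data):
                # Grow by at least len(pts) plus a full buffer's worth.
                data.resize((len(data) + len(pts) + BUFFER_SIZE, 3),
                            refcheck=False)
            data[offset:end] = pts
            offset = end
        data.resize((offset, 3), refcheck=False)  # trim the unused tail
        return data

    streamlines = [np.ones((2, 3)), np.zeros((5, 3)), np.ones((1, 3))]
    assert collect(iter(streamlines)).shape == (8, 3)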
pts = np.asarray(first_element[0]) @@ -279,7 +282,7 @@ def create_from_generator(cls, gen): scals_shape = scals.shape props_shape = props.shape - points._data = np.empty((BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype) + streamlines._data = np.empty((BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype) scalars._data = np.empty((BUFFER_SIZE, scals.shape[1]), dtype=scals.dtype) properties = np.empty((BUFFER_SIZE, props.shape[0]), dtype=props.dtype) @@ -298,14 +301,14 @@ def create_from_generator(cls, gen): " streamline to another") end = offset + len(pts) - if end >= len(points._data): + if end >= len(streamlines._data): # Resize is needed (at least `len(pts)` items will be added). - points._data.resize((len(points._data) + len(pts)+BUFFER_SIZE, pts.shape[1])) + streamlines._data.resize((len(streamlines._data) + len(pts)+BUFFER_SIZE, pts.shape[1])) scalars._data.resize((len(scalars._data) + len(scals)+BUFFER_SIZE, scals.shape[1])) - points._offsets.append(offset) - points._lengths.append(len(pts)) - points._data[offset:offset+len(pts)] = pts + streamlines._offsets.append(offset) + streamlines._lengths.append(len(pts)) + streamlines._data[offset:offset+len(pts)] = pts scalars._data[offset:offset+len(scals)] = scals offset += len(pts) @@ -316,7 +319,7 @@ def create_from_generator(cls, gen): properties[i] = props # Clear unused memory. - points._data.resize((offset, pts.shape[1])) + streamlines._data.resize((offset, pts.shape[1])) if scals_shape[1] == 0: # Because resizing an empty ndarray creates memory! @@ -324,9 +327,9 @@ def create_from_generator(cls, gen): else: scalars._data.resize((offset, scals.shape[1])) - # Share offsets and lengths between points and scalars. - scalars._offsets = points._offsets - scalars._lengths = points._lengths + # Share offsets and lengths between streamlines and scalars. + scalars._offsets = streamlines._offsets + scalars._lengths = streamlines._lengths if props_shape[0] == 0: # Because resizing an empty ndarray creates memory! @@ -334,7 +337,7 @@ def create_from_generator(cls, gen): else: properties.resize((i+1, props.shape[0])) - return cls(points, scalars, properties) + return cls(streamlines, scalars, properties) @property @@ -342,16 +345,16 @@ def header(self): return self._header @property - def points(self): - return self._points + def streamlines(self): + return self._streamlines - @points.setter - def points(self, value): - self._points = value + @streamlines.setter + def streamlines(self, value): + self._streamlines = value if not isinstance(value, CompactList): - self._points = CompactList(value) + self._streamlines = CompactList(value) - self.header.nb_streamlines = len(self.points) + self.header.nb_streamlines = len(self.streamlines) @property def scalars(self): @@ -382,11 +385,11 @@ def properties(self, value): self.header.nb_properties_per_streamline = len(self.properties[0]) def __iter__(self): - for data in zip_longest(self.points, self.scalars, self.properties, fillvalue=None): + for data in zip_longest(self.streamlines, self.scalars, self.properties, fillvalue=None): yield TractogramItem(*data) def __getitem__(self, idx): - pts = self.points[idx] + pts = self.streamlines[idx] scalars = [] if len(self.scalars) > 0: scalars = self.scalars[idx] @@ -401,11 +404,11 @@ def __getitem__(self, idx): return TractogramItem(pts, scalars, properties) def __len__(self): - return len(self.points) + return len(self.streamlines) def copy(self): """ Returns a copy of this `Tractogram` object. 
""" - streamlines = Tractogram(self.points.copy(), self.scalars.copy(), self.properties.copy()) + streamlines = Tractogram(self.streamlines.copy(), self.scalars.copy(), self.properties.copy()) streamlines._header = self.header.copy() return streamlines @@ -419,53 +422,53 @@ def apply_affine(self, affine): affine : 2D array (4,4) Transformation that will be applied on each streamline. """ - if len(self.points) == 0: + if len(self.streamlines) == 0: return - BUFFER_SIZE = 10000 - for i in range(0, len(self.points._data), BUFFER_SIZE): - pts = self.points._data[i:i+BUFFER_SIZE] - self.points._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts) + BUFFER_SIZE = 10000000 # About 128 Mb since pts shape is 3. + for i in range(0, len(self.streamlines._data), BUFFER_SIZE): + pts = self.streamlines._data[i:i+BUFFER_SIZE] + self.streamlines._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts) class LazyTractogram(Tractogram): ''' Class containing information about streamlines. - Tractogram objects have four main properties: ``header``, ``points``, + Tractogram objects have four main properties: ``header``, ``streamlines``, ``scalars`` and ``properties``. Tractogram objects are iterable and - produce tuple of ``points``, ``scalars`` and ``properties`` for each + produce tuple of ``streamlines``, ``scalars`` and ``properties`` for each streamline. Parameters ---------- - points_func : coroutine ouputting (N,3) array-like (optional) - Function yielding streamlines' points. One streamline's points is - an array-like of shape (N,3) where N is the number of points in a - streamline. - - scalars_func : coroutine ouputting (N,M) array-like (optional) - Function yielding streamlines' scalars. One streamline's scalars is - an array-like of shape (N,M) where N is the number of points for a - particular streamline and M is the number of scalars associated to - each point (excluding the three coordinates). + streamlines_func : coroutine ouputting (Nt,3) array-like (optional) + Function yielding streamlines. One streamline is + an ndarray of shape (Nt,3) where Nt is the number of points of + streamline t. + + scalars_func : coroutine ouputting (Nt,M) array-like (optional) + Function yielding scalars for a particular streamline t. The scalars + are represented as an ndarray of shape (Nt,M) where Nt is the number + of points of that streamline t and M is the number of scalars + associated to each point (excluding the three coordinates). properties_func : coroutine ouputting (P,) array-like (optional) - Function yielding streamlines' properties. One streamline's properties - is an array-like of shape (P,) where P is the number of properties - associated to each streamline. + Function yielding properties for a particular streamline t. The + properties are represented as an ndarray of shape (P,) where P is + the number of properties associated to each streamline. getitem_func : function `idx -> 3-tuples` (optional) - Function returning streamlines (one or a list of 3-tuples) given - an index or a slice (i.e. the __getitem__ function to use). + Function returning a subset of the tractogram given an index or a + slice (i.e. the __getitem__ function to use). Notes ----- If provided, ``scalars`` and ``properties`` must yield the same number of - values as ``points``. + values as ``streamlines``. 
''' - def __init__(self, points_func=lambda:[], scalars_func=lambda: [], properties_func=lambda: [], getitem_func=None): - super(LazyTractogram, self).__init__(points_func, scalars_func, properties_func) - self._data = lambda: zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) + def __init__(self, streamlines_func=lambda:[], scalars_func=lambda: [], properties_func=lambda: [], getitem_func=None): + super(LazyTractogram, self).__init__(streamlines_func, scalars_func, properties_func) + self._data = lambda: zip_longest(self.streamlines, self.scalars, self.properties, fillvalue=[]) self._getitem = getitem_func @classmethod @@ -475,12 +478,11 @@ def create_from_data(cls, data_func): Parameters ---------- data_func : coroutine ouputting tuple (optional) - Function yielding 3-tuples, (streamline's points, streamline's - scalars, streamline's properties). A streamline's points is an - array-like of shape (N,3), a streamline's scalars is an array-like - of shape (N,M) and streamline's properties is an array-like of - shape (P,) where N is the number of points for a particular - streamline, M is the number of scalars associated to each point + Function yielding 3-tuples, (streamlines, scalars, properties). + Streamlines are represented as an ndarray of shape (Nt,3), scalars + as an ndarray of shape (Nt,M) and properties as an ndarray of shape + (P,) where Nt is the number of points for a particular + streamline t, M is the number of scalars associated to each point (excluding the three coordinates) and P is the number of properties associated to each streamline. ''' @@ -489,21 +491,21 @@ def create_from_data(cls, data_func): lazy_streamlines = cls() lazy_streamlines._data = data_func - lazy_streamlines.points = lambda: (x[0] for x in data_func()) + lazy_streamlines.streamlines = lambda: (x[0] for x in data_func()) lazy_streamlines.scalars = lambda: (x[1] for x in data_func()) lazy_streamlines.properties = lambda: (x[2] for x in data_func()) return lazy_streamlines @property - def points(self): - return self._points() + def streamlines(self): + return self._streamlines() - @points.setter - def points(self, value): + @streamlines.setter + def streamlines(self, value): if not callable(value): - raise TypeError("`points` must be a coroutine.") + raise TypeError("`streamlines` must be a coroutine.") - self._points = value + self._streamlines = value @property def scalars(self): @@ -564,12 +566,12 @@ def __len__(self): def copy(self): """ Returns a copy of this `LazyTractogram` object. """ - streamlines = LazyTractogram(self._points, self._scalars, self._properties) + streamlines = LazyTractogram(self._streamlines, self._scalars, self._properties) streamlines._header = self.header.copy() return streamlines def transform(self, affine): - """ Applies an affine transformation on the points of each streamline. + """ Applies an affine transformation on the streamlines. Parameters ---------- @@ -583,19 +585,29 @@ def transform(self, affine): """ return super(LazyTractogram, self).transform(affine, lazy=True) - def to_world_space(self): - """ Sends the streamlines back into world space. - Returns - ------- - streamlines : `LazyTractogram` object - Tractogram living in world space. - """ - return super(LazyTractogram, self).to_world_space(lazy=True) +class TractogramFile(object): + ''' Convenience class to encapsulate streamlines file format. 
''' + __metaclass__ = ABCMeta + def __init__(self, tractogram): + self.tractogram = tractogram -class TractogramFile: - ''' Convenience class to encapsulate streamlines file format. ''' + @property + def streamlines(self): + return self.tractogram.streamlines + + @property + def scalars(self): + return self.tractogram.scalars + + @property + def properties(self): + return self.tractogram.properties + + @property + def header(self): + return self.tractogram.header @classmethod def get_magic_number(cls): @@ -692,15 +704,3 @@ def pretty_print(streamlines): Header information relevant to the streamlines file format. ''' raise NotImplementedError() - - -# class DynamicTractogramFile(TractogramFile): -# ''' Convenience class to encapsulate streamlines file format -# that supports appending streamlines to an existing file. -# ''' - -# def append(self, streamlines): -# raise NotImplementedError() - -# def __iadd__(self, streamlines): -# return self.append(streamlines) diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index 1aa650489e..99d8c8ea90 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -208,28 +208,28 @@ class TestTractogramItem(unittest.TestCase): def test_creating_tractogram_item(self): rng = np.random.RandomState(42) - points = rng.rand(rng.randint(10, 50), 3) - scalars = rng.rand(len(points), 5) + streamline = rng.rand(rng.randint(10, 50), 3) + scalars = rng.rand(len(streamline), 5) properties = rng.rand(42) # Create a streamline with only points - s = TractogramItem(points) - assert_equal(len(s), len(points)) - assert_array_equal(s.scalars, []) - assert_array_equal(s.properties, []) + t = TractogramItem(streamline) + assert_equal(len(t), len(streamline)) + assert_array_equal(t.scalars, []) + assert_array_equal(t.properties, []) # Create a streamline with points, scalars and properties. - s = TractogramItem(points, scalars, properties) - assert_equal(len(s), len(points)) - assert_array_equal(s.points, points) - assert_array_equal(list(s), points) - assert_equal(len(s), len(scalars)) - assert_array_equal(s.scalars, scalars) - assert_array_equal(s.properties, properties) + t = TractogramItem(streamline, scalars, properties) + assert_equal(len(t), len(streamline)) + assert_array_equal(t.streamline, streamline) + assert_array_equal(list(t), streamline) + assert_equal(len(t), len(scalars)) + assert_array_equal(t.scalars, scalars) + assert_array_equal(t.properties, properties) # # Create a streamline with different number of scalars. 
- # scalars = rng.rand(len(points)+3, 5) - # assert_raises(ValueError, TractogramItem, points, scalars) + # scalars = rng.rand(len(streamline)+3, 5) + # assert_raises(ValueError, TractogramItem, streamline, scalars) class TestTractogram(unittest.TestCase): @@ -239,9 +239,9 @@ def setUp(self): self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") - self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), - np.arange(2*3, dtype="f4").reshape((2, 3)), - np.arange(5*3, dtype="f4").reshape((5, 3))] + self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), np.array([(0, 1, 0)]*2, dtype="f4"), @@ -251,139 +251,138 @@ def setUp(self): np.array([2.11, 2.22], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] - self.nb_streamlines = len(self.points) + self.nb_tractogram = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) - def test_streamlines_creation_from_arrays(self): + def test_tractogram_creation_from_arrays(self): # Empty - streamlines = Tractogram() - assert_equal(len(streamlines), 0) - assert_arrays_equal(streamlines.points, []) - assert_arrays_equal(streamlines.scalars, []) - assert_arrays_equal(streamlines.properties, []) - - # Check if we can iterate through the streamlines. - for streamline in streamlines: + tractogram = Tractogram() + assert_equal(len(tractogram), 0) + assert_arrays_equal(tractogram.streamlines, []) + assert_arrays_equal(tractogram.scalars, []) + assert_arrays_equal(tractogram.properties, []) + + # Check if we can iterate through the tractogram. + for streamline in tractogram: pass - # Only points - streamlines = Tractogram(points=self.points) - assert_equal(len(streamlines), len(self.points)) - assert_arrays_equal(streamlines.points, self.points) - assert_arrays_equal(streamlines.scalars, []) - assert_arrays_equal(streamlines.properties, []) + # Only streamlines + tractogram = Tractogram(streamlines=self.streamlines) + assert_equal(len(tractogram), len(self.streamlines)) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.scalars, []) + assert_arrays_equal(tractogram.properties, []) - # Check if we can iterate through the streamlines. - for streamline in streamlines: + # Check if we can iterate through the tractogram. + for streamline in tractogram: pass # Points, scalars and properties - streamlines = Tractogram(self.points, self.colors, self.mean_curvature_torsion) - assert_equal(len(streamlines), len(self.points)) - assert_arrays_equal(streamlines.points, self.points) - assert_arrays_equal(streamlines.scalars, self.colors) - assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) - - # Check if we can iterate through the streamlines. - for streamline in streamlines: + tractogram = Tractogram(self.streamlines, self.colors, self.mean_curvature_torsion) + assert_equal(len(tractogram), len(self.streamlines)) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.scalars, self.colors) + assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) + + # Check if we can iterate through the tractogram. 
+ for streamline in tractogram: pass - # Inconsistent number of scalars between points + # Inconsistent number of scalars between streamlines scalars = [[(1, 0, 0)]*1, [(0, 1, 0), (0, 1)], [(0, 0, 1)]*5] - assert_raises(ValueError, Tractogram, self.points, scalars) + assert_raises(ValueError, Tractogram, self.streamlines, scalars) # Unit test moved to test_base_format.py - # Inconsistent number of scalars between streamlines + # Inconsistent number of scalars between tractogram scalars = [[(1, 0, 0)]*1, [(0, 1)]*2, [(0, 0, 1)]*5] - assert_raises(ValueError, Tractogram, self.points, scalars) + assert_raises(ValueError, Tractogram, self.streamlines, scalars) - def test_streamlines_getter(self): - # Tractogram with only points - streamlines = Tractogram(points=self.points) + def test_tractogram_getter(self): + # Tractogram with only streamlines + tractogram = Tractogram(streamlines=self.streamlines) - selected_streamlines = streamlines[::2] - assert_equal(len(selected_streamlines), (len(self.points)+1)//2) + selected_tractogram = tractogram[::2] + assert_equal(len(selected_tractogram), (len(self.streamlines)+1)//2) - assert_arrays_equal(selected_streamlines.points, self.points[::2]) - assert_equal(sum(map(len, selected_streamlines.scalars)), 0) - assert_equal(sum(map(len, selected_streamlines.properties)), 0) + assert_arrays_equal(selected_tractogram.streamlines, self.streamlines[::2]) + assert_equal(sum(map(len, selected_tractogram.scalars)), 0) + assert_equal(sum(map(len, selected_tractogram.properties)), 0) - # Tractogram with points, scalars and properties - streamlines = Tractogram(self.points, self.colors, self.mean_curvature_torsion) + # Tractogram with streamlines, scalars and properties + tractogram = Tractogram(self.streamlines, self.colors, self.mean_curvature_torsion) - # Retrieve streamlines by their index - for i, streamline in enumerate(streamlines): - assert_array_equal(streamline.points, streamlines[i].points) - assert_array_equal(streamline.scalars, streamlines[i].scalars) - assert_array_equal(streamline.properties, streamlines[i].properties) + # Retrieve tractogram by their index + for i, t in enumerate(tractogram): + assert_array_equal(t.streamline, tractogram[i].streamline) + assert_array_equal(t.scalars, tractogram[i].scalars) + assert_array_equal(t.properties, tractogram[i].properties) # Use slicing - r_streamlines = streamlines[::-1] - assert_arrays_equal(r_streamlines.points, self.points[::-1]) - assert_arrays_equal(r_streamlines.scalars, self.colors[::-1]) - assert_arrays_equal(r_streamlines.properties, self.mean_curvature_torsion[::-1]) + r_tractogram = tractogram[::-1] + assert_arrays_equal(r_tractogram.streamlines, self.streamlines[::-1]) + assert_arrays_equal(r_tractogram.scalars, self.colors[::-1]) + assert_arrays_equal(r_tractogram.properties, self.mean_curvature_torsion[::-1]) - def test_streamlines_creation_from_generator(self): + def test_tractogram_creation_from_generator(self): # Create `Tractogram` from a generator yielding 3-tuples - gen = (x for x in zip(self.points, self.colors, self.mean_curvature_torsion)) + gen = (x for x in zip(self.streamlines, self.colors, self.mean_curvature_torsion)) - streamlines = Tractogram.create_from_generator(gen) + tractogram = Tractogram.create_from_generator(gen) with suppress_warnings(): - assert_equal(len(streamlines), self.nb_streamlines) + assert_equal(len(tractogram), self.nb_tractogram) - assert_arrays_equal(streamlines.points, self.points) - assert_arrays_equal(streamlines.scalars, self.colors) - 
assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.scalars, self.colors) + assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) - # Check if we can iterate through the streamlines. - for streamline in streamlines: + # Check if we can iterate through the tractogram. + for streamline in tractogram: pass - def test_streamlines_creation_from_coroutines(self): + def test_tractogram_creation_from_coroutines(self): # Points, scalars and properties - points = lambda: (x for x in self.points) + streamlines = lambda: (x for x in self.streamlines) scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - # To create streamlines from multiple coroutines use `LazyTractogram`. - assert_raises(TypeError, Tractogram, points, scalars, properties) + # To create tractogram from multiple coroutines use `LazyTractogram`. + assert_raises(TypeError, Tractogram, streamlines, scalars, properties) def test_header(self): # Empty Tractogram, with default header - streamlines = Tractogram() - assert_equal(streamlines.header.nb_streamlines, 0) - assert_equal(streamlines.header.nb_scalars_per_point, 0) - assert_equal(streamlines.header.nb_properties_per_streamline, 0) - assert_array_equal(streamlines.header.voxel_sizes, (1, 1, 1)) - assert_array_equal(streamlines.header.to_world_space, np.eye(4)) - assert_equal(streamlines.header.extra, {}) + tractogram = Tractogram() + assert_equal(tractogram.header.nb_streamlines, 0) + assert_equal(tractogram.header.nb_scalars_per_point, 0) + assert_equal(tractogram.header.nb_properties_per_streamline, 0) + assert_array_equal(tractogram.header.voxel_sizes, (1, 1, 1)) + assert_array_equal(tractogram.header.to_world_space, np.eye(4)) + assert_equal(tractogram.header.extra, {}) - streamlines = Tractogram(self.points, self.colors, self.mean_curvature_torsion) + tractogram = Tractogram(self.streamlines, self.colors, self.mean_curvature_torsion) - assert_equal(streamlines.header.nb_streamlines, len(self.points)) - assert_equal(streamlines.header.nb_scalars_per_point, self.colors[0].shape[1]) - assert_equal(streamlines.header.nb_properties_per_streamline, self.mean_curvature_torsion[0].shape[0]) + assert_equal(tractogram.header.nb_streamlines, len(self.streamlines)) + assert_equal(tractogram.header.nb_scalars_per_point, self.colors[0].shape[1]) + assert_equal(tractogram.header.nb_properties_per_streamline, self.mean_curvature_torsion[0].shape[0]) # Modifying voxel_sizes should be reflected in to_world_space - streamlines.header.voxel_sizes = (2, 3, 4) - assert_array_equal(streamlines.header.voxel_sizes, (2, 3, 4)) - assert_array_equal(np.diag(streamlines.header.to_world_space), (2, 3, 4, 1)) + tractogram.header.voxel_sizes = (2, 3, 4) + assert_array_equal(tractogram.header.voxel_sizes, (2, 3, 4)) + assert_array_equal(np.diag(tractogram.header.to_world_space), (2, 3, 4, 1)) # Modifying scaling of to_world_space should be reflected in voxel_sizes - streamlines.header.to_world_space = np.diag([4, 3, 2, 1]) - assert_array_equal(streamlines.header.voxel_sizes, (4, 3, 2)) - assert_array_equal(streamlines.header.to_world_space, np.diag([4, 3, 2, 1])) + tractogram.header.to_world_space = np.diag([4, 3, 2, 1]) + assert_array_equal(tractogram.header.voxel_sizes, (4, 3, 2)) + assert_array_equal(tractogram.header.to_world_space, np.diag([4, 3, 2, 1])) # Test that we can run __repr__ without error. 
- repr(streamlines.header) - + repr(tractogram.header) class TestLazyTractogram(unittest.TestCase): @@ -393,9 +392,9 @@ def setUp(self): self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") - self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), - np.arange(2*3, dtype="f4").reshape((2, 3)), - np.arange(5*3, dtype="f4").reshape((5, 3))] + self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), np.array([(0, 1, 0)]*2, dtype="f4"), @@ -405,99 +404,99 @@ def setUp(self): np.array([2.11, 2.22], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] - self.nb_streamlines = len(self.points) + self.nb_streamlines = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) - def test_lazy_streamlines_creation(self): - # To create streamlines from arrays use `Tractogram`. - assert_raises(TypeError, LazyTractogram, self.points) + def test_lazy_tractogram_creation(self): + # To create a tractogram from arrays, use `Tractogram`. + assert_raises(TypeError, LazyTractogram, self.streamlines) # Points, scalars and properties - points = (x for x in self.points) + streamlines = (x for x in self.streamlines) scalars = (x for x in self.colors) properties = (x for x in self.mean_curvature_torsion) # Creating LazyTractogram from generators is not allowed as # generators get exhausted and are not reusable unlike coroutines. - assert_raises(TypeError, LazyTractogram, points) - assert_raises(TypeError, LazyTractogram, self.points, scalars) + assert_raises(TypeError, LazyTractogram, streamlines) + assert_raises(TypeError, LazyTractogram, self.streamlines, scalars) assert_raises(TypeError, LazyTractogram, properties_func=properties) # Empty `LazyTractogram` - streamlines = LazyTractogram() + tractogram = LazyTractogram() with suppress_warnings(): - assert_equal(len(streamlines), 0) - assert_arrays_equal(streamlines.points, []) - assert_arrays_equal(streamlines.scalars, []) - assert_arrays_equal(streamlines.properties, []) + assert_equal(len(tractogram), 0) + assert_arrays_equal(tractogram.streamlines, []) + assert_arrays_equal(tractogram.scalars, []) + assert_arrays_equal(tractogram.properties, []) - # Check if we can iterate through the streamlines. - for streamline in streamlines: + # Check if we can iterate through the tractogram. + for streamline in tractogram: pass # Points, scalars and properties - points = lambda: (x for x in self.points) + streamlines = lambda: (x for x in self.streamlines) scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - streamlines = LazyTractogram(points, scalars, properties) + tractogram = LazyTractogram(streamlines, scalars, properties) with suppress_warnings(): - assert_equal(len(streamlines), self.nb_streamlines) + assert_equal(len(tractogram), self.nb_streamlines) # Coroutines get re-called and create new iterators.
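The assertions that follow depend on that re-calling behaviour: a plain generator can be consumed only once, whereas a zero-argument callable (a "coroutine" in this patch's wording) builds a fresh generator on every call. A minimal illustration, independent of nibabel:

    # A generator is exhausted after a single pass.
    gen = (i * i for i in range(3))
    assert list(gen) == [0, 1, 4]
    assert list(gen) == []          # nothing left on the second pass

    # A callable returning a new generator is reusable.
    make_gen = lambda: (i * i for i in range(3))
    assert list(make_gen()) == [0, 1, 4]
    assert list(make_gen()) == [0, 1, 4]   # fresh iterator each call

This is why `LazyTractogram` rejects raw generators but accepts callables.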
- assert_arrays_equal(streamlines.points, self.points) - assert_arrays_equal(streamlines.scalars, self.colors) - assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) - assert_arrays_equal(streamlines.points, self.points) - assert_arrays_equal(streamlines.scalars, self.colors) - assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.scalars, self.colors) + assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.scalars, self.colors) + assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) # Create `LazyTractogram` from a coroutine yielding 3-tuples - data = lambda: (x for x in zip(self.points, self.colors, self.mean_curvature_torsion)) + data = lambda: (x for x in zip(self.streamlines, self.colors, self.mean_curvature_torsion)) - streamlines = LazyTractogram.create_from_data(data) + tractogram = LazyTractogram.create_from_data(data) with suppress_warnings(): - assert_equal(len(streamlines), self.nb_streamlines) - assert_arrays_equal(streamlines.points, self.points) - assert_arrays_equal(streamlines.scalars, self.colors) - assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + assert_equal(len(tractogram), self.nb_streamlines) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.scalars, self.colors) + assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) - # Check if we can iterate through the streamlines. - for streamline in streamlines: + # Check if we can iterate through the tractogram. + for streamline in tractogram: pass - def test_lazy_streamlines_indexing(self): - points = lambda: (x for x in self.points) + def test_lazy_tractogram_indexing(self): + streamlines = lambda: (x for x in self.streamlines) scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) # By default, `LazyTractogram` object does not support indexing. - streamlines = LazyTractogram(points, scalars, properties) - assert_raises(AttributeError, streamlines.__getitem__, 0) + tractogram = LazyTractogram(streamlines, scalars, properties) + assert_raises(AttributeError, tractogram.__getitem__, 0) # Create a `LazyTractogram` object with indexing support. 
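The `getitem_without_properties` helper defined just below supplies that indexing support. As a standalone sketch, a lazy container can delegate `__getitem__` to an optional user-provided callable and refuse indexing otherwise (the `LazySequence` class is illustrative, not the nibabel implementation):

    class LazySequence:
        def __init__(self, gen_func, getitem=None):
            self._gen_func = gen_func   # callable returning a fresh generator
            self._getitem = getitem     # optional indexing callable

        def __iter__(self):
            return self._gen_func()

        def __getitem__(self, idx):
            if self._getitem is None:
                raise AttributeError("indexing is not supported")
            return self._getitem(idx)

    data = [0, 1, 4, 9, 16]
    squares = LazySequence(lambda: iter(data), getitem=lambda idx: data[idx])
    assert list(squares) == data
    assert squares[2] == 4
    assert squares[::-1] == [16, 9, 4, 1, 0]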
def getitem_without_properties(idx): if isinstance(idx, int) or isinstance(idx, np.integer): - return self.points[idx], self.colors[idx] + return self.streamlines[idx], self.colors[idx] - return list(zip(self.points[idx], self.colors[idx])) + return list(zip(self.streamlines[idx], self.colors[idx])) - streamlines = LazyTractogram(points, scalars, properties, getitem_without_properties) - points, scalars = streamlines[0] - assert_array_equal(points, self.points[0]) + tractogram = LazyTractogram(streamlines, scalars, properties, getitem_without_properties) + streamlines, scalars = tractogram[0] + assert_array_equal(streamlines, self.streamlines[0]) assert_array_equal(scalars, self.colors[0]) - points, scalars = zip(*streamlines[::-1]) - assert_arrays_equal(points, self.points[::-1]) + streamlines, scalars = zip(*tractogram[::-1]) + assert_arrays_equal(streamlines, self.streamlines[::-1]) assert_arrays_equal(scalars, self.colors[::-1]) - points, scalars = zip(*streamlines[:-1]) - assert_arrays_equal(points, self.points[:-1]) + streamlines, scalars = zip(*tractogram[:-1]) + assert_arrays_equal(streamlines, self.streamlines[:-1]) assert_arrays_equal(scalars, self.colors[:-1]) - def test_lazy_streamlines_len(self): - points = lambda: (x for x in self.points) + def test_lazy_tractogram_len(self): + streamlines = lambda: (x for x in self.streamlines) scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) @@ -505,61 +504,61 @@ def test_lazy_streamlines_len(self): warnings.simplefilter("always") # Always trigger warnings. # Calling `len` will create new generators each time. - streamlines = LazyTractogram(points, scalars, properties) + tractogram = LazyTractogram(streamlines, scalars, properties) # This should produce a warning message. - assert_equal(len(streamlines), self.nb_streamlines) + assert_equal(len(tractogram), self.nb_streamlines) assert_equal(len(w), 1) - streamlines = LazyTractogram(points, scalars, properties) + tractogram = LazyTractogram(streamlines, scalars, properties) # This should still produce a warning message. - assert_equal(len(streamlines), self.nb_streamlines) + assert_equal(len(tractogram), self.nb_streamlines) assert_equal(len(w), 2) assert_true(issubclass(w[-1].category, UsageWarning)) # This should *not* produce a warning. - assert_equal(len(streamlines), self.nb_streamlines) + assert_equal(len(tractogram), self.nb_streamlines) assert_equal(len(w), 2) with clear_and_catch_warnings(record=True, modules=[base_format]) as w: - # Once we iterated through the streamlines, we know the length. - streamlines = LazyTractogram(points, scalars, properties) - assert_true(streamlines.header.nb_streamlines is None) - for streamline in streamlines: + # Once we iterated through the tractogram, we know the length. + tractogram = LazyTractogram(streamlines, scalars, properties) + assert_true(tractogram.header.nb_streamlines is None) + for streamline in tractogram: pass - assert_equal(streamlines.header.nb_streamlines, len(self.points)) + assert_equal(tractogram.header.nb_streamlines, len(self.streamlines)) # This should *not* produce a warning. - assert_equal(len(streamlines), len(self.points)) + assert_equal(len(tractogram), len(self.streamlines)) assert_equal(len(w), 0) with clear_and_catch_warnings(record=True, modules=[base_format]) as w: - # It first checks if number of streamlines is in the header. 
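A sketch of the length protocol these tests exercise: return a length that is already known (from the header or a previous count) silently, otherwise warn and count by exhausting a fresh generator, caching the result so later calls stay silent (names are hypothetical):

    import warnings

    class LazyLen:
        def __init__(self, gen_func):
            self._gen_func = gen_func
            self._len = None               # e.g. header.nb_streamlines

        def __len__(self):
            if self._len is not None:      # known or cached: no warning
                return self._len
            warnings.warn("counting streamlines by iterating; this is costly")
            self._len = sum(1 for _ in self._gen_func())
            return self._len

    lazy = LazyLen(lambda: iter(range(3)))
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        assert len(lazy) == 3              # first call warns
        assert len(lazy) == 3              # cached: no new warning
        assert len(w) == 1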
- streamlines = LazyTractogram(points, scalars, properties) - streamlines.header.nb_streamlines = 1234 + # It first checks if the number of streamlines is in the header. + tractogram = LazyTractogram(streamlines, scalars, properties) + tractogram.header.nb_streamlines = 1234 # This should *not* produce a warning. - assert_equal(len(streamlines), 1234) + assert_equal(len(tractogram), 1234) assert_equal(len(w), 0) - def test_lazy_streamlines_header(self): + def test_lazy_tractogram_header(self): # Empty `LazyTractogram`, with default header - streamlines = LazyTractogram() - assert_true(streamlines.header.nb_streamlines is None) - assert_equal(streamlines.header.nb_scalars_per_point, 0) - assert_equal(streamlines.header.nb_properties_per_streamline, 0) - assert_array_equal(streamlines.header.voxel_sizes, (1, 1, 1)) - assert_array_equal(streamlines.header.to_world_space, np.eye(4)) - assert_equal(streamlines.header.extra, {}) - - points = lambda: (x for x in self.points) + tractogram = LazyTractogram() + assert_true(tractogram.header.nb_streamlines is None) + assert_equal(tractogram.header.nb_scalars_per_point, 0) + assert_equal(tractogram.header.nb_properties_per_streamline, 0) + assert_array_equal(tractogram.header.voxel_sizes, (1, 1, 1)) + assert_array_equal(tractogram.header.to_world_space, np.eye(4)) + assert_equal(tractogram.header.extra, {}) + + streamlines = lambda: (x for x in self.streamlines) scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - streamlines = LazyTractogram(points) - header = streamlines.header + tractogram = LazyTractogram(streamlines) + header = tractogram.header assert_equal(header.nb_scalars_per_point, 0) - streamlines.scalars = scalars + tractogram.scalars = scalars assert_equal(header.nb_scalars_per_point, self.nb_scalars_per_point) assert_equal(header.nb_properties_per_streamline, 0) - streamlines.properties = properties + tractogram.properties = properties assert_equal(header.nb_properties_per_streamline, self.nb_properties_per_streamline) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 1e94b79911..5cd30c65f6 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -9,7 +9,7 @@ from nibabel.externals.six import BytesIO from nibabel.testing import clear_and_catch_warnings -from nibabel.testing import assert_arrays_equal +from nibabel.testing import assert_arrays_equal, isiterable from nose.tools import assert_equal, assert_raises, assert_true, assert_false from ..base_format import Tractogram, LazyTractogram @@ -20,29 +20,19 @@ DATA_PATH = pjoin(os.path.dirname(__file__), 'data') -def isiterable(streamlines): - try: - for _ in streamlines: - pass - except: - return False - - return True - - -def check_streamlines(streamlines, nb_streamlines, points, scalars, properties): +def check_tractogram(tractogram, nb_streamlines, streamlines, scalars, properties): # Check data - assert_equal(len(streamlines), nb_streamlines) - assert_arrays_equal(streamlines.points, points) - assert_arrays_equal(streamlines.scalars, scalars) - assert_arrays_equal(streamlines.properties, properties) - assert_true(isiterable(streamlines)) + assert_equal(len(tractogram), nb_streamlines) + assert_arrays_equal(tractogram.streamlines, streamlines) + assert_arrays_equal(tractogram.scalars, scalars) + assert_arrays_equal(tractogram.properties, properties) + assert_true(isiterable(tractogram)) -
assert_equal(streamlines.header.nb_streamlines, nb_streamlines) + assert_equal(tractogram.header.nb_streamlines, nb_streamlines) nb_scalars_per_point = 0 if len(scalars) == 0 else len(scalars[0][0]) nb_properties_per_streamline = 0 if len(properties) == 0 else len(properties[0]) - assert_equal(streamlines.header.nb_scalars_per_point, nb_scalars_per_point) - assert_equal(streamlines.header.nb_properties_per_streamline, nb_properties_per_streamline) + assert_equal(tractogram.header.nb_scalars_per_point, nb_scalars_per_point) + assert_equal(tractogram.header.nb_properties_per_streamline, nb_properties_per_streamline) def test_is_supported(): @@ -52,33 +42,33 @@ def test_is_supported(): assert_false(nib.streamlines.is_supported("")) # Valid file without extension - for streamlines_file in nib.streamlines.FORMATS.values(): + for tractogram_file in nib.streamlines.FORMATS.values(): f = BytesIO() - f.write(streamlines_file.get_magic_number()) + f.write(tractogram_file.get_magic_number()) f.seek(0, os.SEEK_SET) assert_true(nib.streamlines.is_supported(f)) # Wrong extension but right magic number - for streamlines_file in nib.streamlines.FORMATS.values(): + for tractogram_file in nib.streamlines.FORMATS.values(): with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f: - f.write(streamlines_file.get_magic_number()) + f.write(tractogram_file.get_magic_number()) f.seek(0, os.SEEK_SET) assert_true(nib.streamlines.is_supported(f)) # Good extension but wrong magic number - for ext, streamlines_file in nib.streamlines.FORMATS.items(): + for ext, tractogram_file in nib.streamlines.FORMATS.items(): with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f: f.write(b"pass") f.seek(0, os.SEEK_SET) assert_false(nib.streamlines.is_supported(f)) # Wrong extension, string only - f = "my_streamlines.asd" + f = "my_tractogram.asd" assert_false(nib.streamlines.is_supported(f)) # Good extension, string only - for ext, streamlines_file in nib.streamlines.FORMATS.items(): - f = "my_streamlines" + ext + for ext, tractogram_file in nib.streamlines.FORMATS.items(): + f = "my_tractogram" + ext assert_true(nib.streamlines.is_supported(f)) @@ -89,34 +79,34 @@ def test_detect_format(): assert_equal(nib.streamlines.detect_format(""), None) # Valid file without extension - for streamlines_file in nib.streamlines.FORMATS.values(): + for tractogram_file in nib.streamlines.FORMATS.values(): f = BytesIO() - f.write(streamlines_file.get_magic_number()) + f.write(tractogram_file.get_magic_number()) f.seek(0, os.SEEK_SET) - assert_equal(nib.streamlines.detect_format(f), streamlines_file) + assert_equal(nib.streamlines.detect_format(f), tractogram_file) # Wrong extension but right magic number - for streamlines_file in nib.streamlines.FORMATS.values(): + for tractogram_file in nib.streamlines.FORMATS.values(): with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f: - f.write(streamlines_file.get_magic_number()) + f.write(tractogram_file.get_magic_number()) f.seek(0, os.SEEK_SET) - assert_equal(nib.streamlines.detect_format(f), streamlines_file) + assert_equal(nib.streamlines.detect_format(f), tractogram_file) # Good extension but wrong magic number - for ext, streamlines_file in nib.streamlines.FORMATS.items(): + for ext, tractogram_file in nib.streamlines.FORMATS.items(): with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f: f.write(b"pass") f.seek(0, os.SEEK_SET) assert_equal(nib.streamlines.detect_format(f), None) # Wrong extension, string only - f = "my_streamlines.asd" + f = "my_tractogram.asd" 
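Both helpers apply the same two rules that these tests pin down: an open file is sniffed by magic number regardless of its extension, while a bare filename can only be judged by its extension. A toy sketch of that dispatch (the `FORMATS` registry here maps extensions to magic numbers for brevity; the real one maps them to `TractogramFile` classes exposing `get_magic_number()` and `is_correct_format()`):

    import os
    from io import BytesIO

    FORMATS = {".trk": b"TRACK", ".tck": b"mrtrix tracks"}

    def detect_format_sketch(fileobj):
        if hasattr(fileobj, "read"):           # open file: trust the content
            start = fileobj.read(max(len(m) for m in FORMATS.values()))
            fileobj.seek(0, os.SEEK_SET)
            for ext, magic in FORMATS.items():
                if start.startswith(magic):
                    return ext
            return None
        _, ext = os.path.splitext(fileobj)     # bare name: trust the extension
        return ext if ext in FORMATS else None

    assert detect_format_sketch(BytesIO(b"TRACK" + b"\x00" * 10)) == ".trk"
    assert detect_format_sketch("my_tractogram.asd") is None
    assert detect_format_sketch("my_tractogram.tck") == ".tck"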
assert_equal(nib.streamlines.detect_format(f), None) # Good extension, string only - for ext, streamlines_file in nib.streamlines.FORMATS.items(): - f = "my_streamlines" + ext - assert_equal(nib.streamlines.detect_format(f), streamlines_file) + for ext, tractogram_file in nib.streamlines.FORMATS.items(): + f = "my_tractogram" + ext + assert_equal(nib.streamlines.detect_format(f), tractogram_file) class TestLoadSave(unittest.TestCase): @@ -125,9 +115,9 @@ def setUp(self): self.simple_filenames = [pjoin(DATA_PATH, "simple" + ext) for ext in nib.streamlines.FORMATS.keys()] self.complex_filenames = [pjoin(DATA_PATH, "complex" + ext) for ext in nib.streamlines.FORMATS.keys()] - self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), - np.arange(2*3, dtype="f4").reshape((2, 3)), - np.arange(5*3, dtype="f4").reshape((5, 3))] + self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), np.array([(0, 1, 0)]*2, dtype="f4"), @@ -137,35 +127,35 @@ def setUp(self): np.array([2.11, 2.22], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] - self.nb_streamlines = len(self.points) + self.nb_streamlines = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) self.to_world_space = np.eye(4) def test_load_empty_file(self): for empty_filename in self.empty_filenames: - streamlines = nib.streamlines.load(empty_filename, + tractogram = nib.streamlines.load(empty_filename, ref=self.to_world_space, lazy_load=False) - assert_true(type(streamlines), Tractogram) - check_streamlines(streamlines, 0, [], [], []) + assert_true(type(tractogram), Tractogram) + check_tractogram(tractogram, 0, [], [], []) def test_load_simple_file(self): for simple_filename in self.simple_filenames: - streamlines = nib.streamlines.load(simple_filename, + tractogram = nib.streamlines.load(simple_filename, ref=self.to_world_space, lazy_load=False) - assert_true(type(streamlines), Tractogram) - check_streamlines(streamlines, self.nb_streamlines, - self.points, [], []) + assert_true(type(tractogram), Tractogram) + check_tractogram(tractogram, self.nb_streamlines, + self.streamlines, [], []) # Test lazy_load - streamlines = nib.streamlines.load(simple_filename, + tractogram = nib.streamlines.load(simple_filename, ref=self.to_world_space, lazy_load=True) - assert_true(type(streamlines), LazyTractogram) - check_streamlines(streamlines, self.nb_streamlines, - self.points, [], []) + assert_true(type(tractogram), LazyTractogram) + check_tractogram(tractogram, self.nb_streamlines, + self.streamlines, [], []) def test_load_complex_file(self): for complex_filename in self.complex_filenames: @@ -179,36 +169,36 @@ def test_load_complex_file(self): if file_format.can_save_properties(): properties = self.mean_curvature_torsion - streamlines = nib.streamlines.load(complex_filename, + tractogram = nib.streamlines.load(complex_filename, ref=self.to_world_space, lazy_load=False) - assert_true(type(streamlines), Tractogram) - check_streamlines(streamlines, self.nb_streamlines, - self.points, scalars, properties) + assert_true(type(tractogram), Tractogram) + check_tractogram(tractogram, self.nb_streamlines, + self.streamlines, scalars, properties) # Test lazy_load - streamlines = nib.streamlines.load(complex_filename, + tractogram = nib.streamlines.load(complex_filename, ref=self.to_world_space, lazy_load=True) - 
assert_true(type(streamlines), LazyTractogram) - check_streamlines(streamlines, self.nb_streamlines, - self.points, scalars, properties) + assert_true(type(tractogram), LazyTractogram) + check_tractogram(tractogram, self.nb_streamlines, + self.streamlines, scalars, properties) def test_save_simple_file(self): - streamlines = Tractogram(self.points) + tractogram = Tractogram(self.streamlines) for ext in nib.streamlines.FORMATS.keys(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: - nib.streamlines.save(streamlines, f.name) - loaded_streamlines = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) - check_streamlines(loaded_streamlines, self.nb_streamlines, - self.points, [], []) + nib.streamlines.save(tractogram, f.name) + loaded_tractogram = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) + check_tractogram(loaded_tractogram, self.nb_streamlines, + self.streamlines, [], []) def test_save_complex_file(self): - streamlines = Tractogram(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) + tractogram = Tractogram(self.streamlines, scalars=self.colors, properties=self.mean_curvature_torsion) for ext, cls in nib.streamlines.FORMATS.items(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: with clear_and_catch_warnings(record=True, modules=[trk]) as w: - nib.streamlines.save(streamlines, f.name) + nib.streamlines.save(tractogram, f.name) # If streamlines format does not support saving scalars or # properties, a warning message should be issued. @@ -224,6 +214,6 @@ def test_save_complex_file(self): if cls.can_save_properties(): properties = self.mean_curvature_torsion - loaded_streamlines = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) - check_streamlines(loaded_streamlines, self.nb_streamlines, - self.points, scalars, properties) + loaded_tractogram = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) + check_tractogram(loaded_tractogram, self.nb_streamlines, + self.streamlines, scalars, properties) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 0c5d58aadf..40720027fc 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -5,7 +5,7 @@ from nibabel.externals.six import BytesIO from nibabel.testing import suppress_warnings, clear_and_catch_warnings -from nibabel.testing import assert_arrays_equal, assert_streamlines_equal +from nibabel.testing import assert_arrays_equal, assert_tractogram_equal, isiterable from nose.tools import assert_equal, assert_raises, assert_true from .. 
import base_format @@ -18,43 +18,33 @@ DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') -def isiterable(streamlines): - try: - for streamline in streamlines: - pass - except: - return False - - return True - - -def check_streamlines(streamlines, nb_streamlines, points, scalars, properties): +def check_tractogram(tractogram, nb_streamlines, streamlines, scalars, properties): # Check data - assert_equal(len(streamlines), nb_streamlines) - assert_arrays_equal(streamlines.points, points) - assert_arrays_equal(streamlines.scalars, scalars) - assert_arrays_equal(streamlines.properties, properties) - assert_true(isiterable(streamlines)) + assert_equal(len(tractogram), nb_streamlines) + assert_arrays_equal(tractogram.streamlines, streamlines) + assert_arrays_equal(tractogram.scalars, scalars) + assert_arrays_equal(tractogram.properties, properties) + assert_true(isiterable(tractogram)) - assert_equal(streamlines.header.nb_streamlines, nb_streamlines) + assert_equal(tractogram.header.nb_streamlines, nb_streamlines) nb_scalars_per_point = 0 if len(scalars) == 0 else len(scalars[0][0]) nb_properties_per_streamline = 0 if len(properties) == 0 else len(properties[0]) - assert_equal(streamlines.header.nb_scalars_per_point, nb_scalars_per_point) - assert_equal(streamlines.header.nb_properties_per_streamline, nb_properties_per_streamline) + assert_equal(tractogram.header.nb_scalars_per_point, nb_scalars_per_point) + assert_equal(tractogram.header.nb_properties_per_streamline, nb_properties_per_streamline) class TestTRK(unittest.TestCase): def setUp(self): self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") - # simple.trk contains only points + # simple.trk contains only streamlines self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") - # complex.trk contains points, scalars and properties + # complex.trk contains streamlines, scalars and properties self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") - self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), - np.arange(2*3, dtype="f4").reshape((2, 3)), - np.arange(5*3, dtype="f4").reshape((5, 3))] + self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), np.array([(0, 1, 0)]*2, dtype="f4"), @@ -64,34 +54,34 @@ def setUp(self): np.array([2.11, 2.22], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] - self.nb_streamlines = len(self.points) + self.nb_streamlines = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) def test_load_empty_file(self): trk = TrkFile.load(self.empty_trk_filename, ref=None, lazy_load=False) - check_streamlines(trk, 0, [], [], []) + check_tractogram(trk, 0, [], [], []) trk = TrkFile.load(self.empty_trk_filename, ref=None, lazy_load=True) # Suppress warning about loading a TRK file in lazy mode with count=0. 
with suppress_warnings(): - check_streamlines(trk, 0, [], [], []) + check_tractogram(trk, 0, [], [], []) def test_load_simple_file(self): trk = TrkFile.load(self.simple_trk_filename, ref=None, lazy_load=False) - check_streamlines(trk, self.nb_streamlines, self.points, [], []) + check_tractogram(trk, self.nb_streamlines, self.streamlines, [], []) trk = TrkFile.load(self.simple_trk_filename, ref=None, lazy_load=True) - check_streamlines(trk, self.nb_streamlines, self.points, [], []) + check_tractogram(trk, self.nb_streamlines, self.streamlines, [], []) def test_load_complex_file(self): trk = TrkFile.load(self.complex_trk_filename, ref=None, lazy_load=False) - check_streamlines(trk, self.nb_streamlines, - self.points, self.colors, self.mean_curvature_torsion) + check_tractogram(trk, self.nb_streamlines, + self.streamlines, self.colors, self.mean_curvature_torsion) trk = TrkFile.load(self.complex_trk_filename, ref=None, lazy_load=True) - check_streamlines(trk, self.nb_streamlines, - self.points, self.colors, self.mean_curvature_torsion) + check_tractogram(trk, self.nb_streamlines, + self.streamlines, self.colors, self.mean_curvature_torsion) def test_load_file_with_wrong_information(self): trk_file = open(self.simple_trk_filename, 'rb').read() @@ -99,12 +89,12 @@ def test_load_file_with_wrong_information(self): # Simulate a TRK file where `count` was not provided. count = np.array(0, dtype="int32").tostring() new_trk_file = trk_file[:1000-12] + count + trk_file[1000-8:] - streamlines = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) - check_streamlines(streamlines, self.nb_streamlines, self.points, [], []) + tractogram = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) + check_tractogram(tractogram, self.nb_streamlines, self.streamlines, [], []) - streamlines = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) + tractogram = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) with clear_and_catch_warnings(record=True, modules=[base_format]) as w: - check_streamlines(streamlines, self.nb_streamlines, self.points, [], []) + check_tractogram(tractogram, self.nb_streamlines, self.streamlines, [], []) assert_equal(len(w), 1) assert_true(issubclass(w[0].category, UsageWarning)) @@ -128,59 +118,59 @@ def test_load_file_with_wrong_information(self): assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) def test_write_simple_file(self): - streamlines = Tractogram(self.points) + tractogram = Tractogram(self.streamlines) trk_file = BytesIO() - TrkFile.save(streamlines, trk_file) + TrkFile.save(tractogram, trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_streamlines = TrkFile.load(trk_file) - check_streamlines(loaded_streamlines, self.nb_streamlines, - self.points, [], []) + loaded_tractogram = TrkFile.load(trk_file) + check_tractogram(loaded_tractogram, self.nb_streamlines, + self.streamlines, [], []) - loaded_streamlines_orig = TrkFile.load(self.simple_trk_filename) - assert_streamlines_equal(loaded_streamlines, loaded_streamlines_orig) + loaded_tractogram_orig = TrkFile.load(self.simple_trk_filename) + assert_tractogram_equal(loaded_tractogram, loaded_tractogram_orig) trk_file.seek(0, os.SEEK_SET) assert_equal(open(self.simple_trk_filename, 'rb').read(), trk_file.read()) def test_write_complex_file(self): # With scalars - streamlines = Tractogram(self.points, scalars=self.colors) + tractogram = Tractogram(self.streamlines, scalars=self.colors) trk_file = BytesIO() - TrkFile.save(streamlines, trk_file) + TrkFile.save(tractogram, trk_file) trk_file.seek(0, os.SEEK_SET) - 
loaded_streamlines = TrkFile.load(trk_file, ref=None, lazy_load=False) + loaded_tractogram = TrkFile.load(trk_file, ref=None, lazy_load=False) - check_streamlines(loaded_streamlines, self.nb_streamlines, - self.points, self.colors, []) + check_tractogram(loaded_tractogram, self.nb_streamlines, + self.streamlines, self.colors, []) # With properties - streamlines = Tractogram(self.points, properties=self.mean_curvature_torsion) + tractogram = Tractogram(self.streamlines, properties=self.mean_curvature_torsion) trk_file = BytesIO() - TrkFile.save(streamlines, trk_file) + TrkFile.save(tractogram, trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_streamlines = TrkFile.load(trk_file, ref=None, lazy_load=False) - check_streamlines(loaded_streamlines, self.nb_streamlines, - self.points, [], self.mean_curvature_torsion) + loaded_tractogram = TrkFile.load(trk_file, ref=None, lazy_load=False) + check_tractogram(loaded_tractogram, self.nb_streamlines, + self.streamlines, [], self.mean_curvature_torsion) # With scalars and properties - streamlines = Tractogram(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) + tractogram = Tractogram(self.streamlines, scalars=self.colors, properties=self.mean_curvature_torsion) trk_file = BytesIO() - TrkFile.save(streamlines, trk_file) + TrkFile.save(tractogram, trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_streamlines = TrkFile.load(trk_file, ref=None, lazy_load=False) - check_streamlines(loaded_streamlines, self.nb_streamlines, - self.points, self.colors, self.mean_curvature_torsion) + loaded_tractogram = TrkFile.load(trk_file, ref=None, lazy_load=False) + check_tractogram(loaded_tractogram, self.nb_streamlines, + self.streamlines, self.colors, self.mean_curvature_torsion) - loaded_streamlines_orig = TrkFile.load(self.complex_trk_filename) - assert_streamlines_equal(loaded_streamlines, loaded_streamlines_orig) + loaded_tractogram_orig = TrkFile.load(self.complex_trk_filename) + assert_tractogram_equal(loaded_tractogram, loaded_tractogram_orig) trk_file.seek(0, os.SEEK_SET) assert_equal(open(self.complex_trk_filename, 'rb').read(), trk_file.read()) @@ -191,15 +181,15 @@ def test_write_erroneous_file(self): [(0, 1, 0)], [(0, 0, 1)]] - streamlines = Tractogram(self.points, scalars) - assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + tractogram = Tractogram(self.streamlines, scalars) + assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) # Not enough scalars for every streamline scalars = [[(1, 0, 0)]*1, [(0, 1, 0)]*2] - streamlines = Tractogram(self.points, scalars) - assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + tractogram = Tractogram(self.streamlines, scalars) + assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) # # Unit test moved to test_base_format.py # # Inconsistent number of scalars between points # scalars = [[(1, 0, 0)]*1, # [(0, 1, 0), (0, 1)], # [(0, 0, 1)]*5] - # streamlines = Tractogram(self.points, scalars) - # assert_raises(ValueError, TrkFile.save, streamlines, BytesIO()) + # tractogram = Tractogram(self.streamlines, scalars) + # assert_raises(ValueError, TrkFile.save, tractogram, BytesIO()) # # Unit test moved to test_base_format.py # # Inconsistent number of scalars between streamlines # scalars = [[(1, 0, 0)]*1, # [(0, 1)]*2, # [(0, 0, 1)]*5] - # streamlines = Tractogram(self.points, scalars) - # assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + # tractogram = Tractogram(self.streamlines, scalars) + #
assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) # Unit test moved to test_base_format.py # Inconsistent number of properties properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] - streamlines = Tractogram(self.points, properties=properties) - assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + tractogram = Tractogram(self.streamlines, properties=properties) + assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) # Unit test moved to test_base_format.py # Not enough properties for every streamline properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11, 2.22], dtype="f4")] - streamlines = Tractogram(self.points, properties=properties) - assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + tractogram = Tractogram(self.streamlines, properties=properties) + assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) - def test_write_file_lazy_streamlines(self): - points = lambda: (point for point in self.points) + def test_write_file_lazy_tractogram(self): + streamlines = lambda: (point for point in self.streamlines) scalars = lambda: (scalar for scalar in self.colors) properties = lambda: (prop for prop in self.mean_curvature_torsion) - streamlines = LazyTractogram(points, scalars, properties) + tractogram = LazyTractogram(streamlines, scalars, properties) # No need to manually set `nb_streamlines` in the header since we count # them while writing. - #streamlines.header.nb_streamlines = self.nb_streamlines + #tractogram.header.nb_streamlines = self.nb_streamlines trk_file = BytesIO() - TrkFile.save(streamlines, trk_file) + TrkFile.save(tractogram, trk_file) trk_file.seek(0, os.SEEK_SET) trk = TrkFile.load(trk_file, ref=None, lazy_load=False) - check_streamlines(trk, self.nb_streamlines, - self.points, self.colors, self.mean_curvature_torsion) + check_tractogram(trk, self.nb_streamlines, + self.streamlines, self.colors, self.mean_curvature_torsion) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 3194218190..ae913ad67f 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -250,17 +250,17 @@ def __init__(self, fileobj, header): self.beginning = self.file.tell() self.file.write(self.header[0].tostring()) - def write(self, streamlines): + def write(self, tractogram): i4_dtype = np.dtype("i4") f4_dtype = np.dtype("f4") - for s in streamlines: - if len(s.scalars) > 0 and len(s.scalars) != len(s.points): + for t in tractogram: + if len(t.scalars) > 0 and len(t.scalars) != len(t.streamline): raise DataError("Missing scalars for some points!") - points = np.asarray(s.points, dtype=f4_dtype) - scalars = np.asarray(s.scalars, dtype=f4_dtype).reshape((len(points), -1)) - properties = np.asarray(s.properties, dtype=f4_dtype) + points = np.asarray(t.streamline, dtype=f4_dtype) + scalars = np.asarray(t.scalars, dtype=f4_dtype).reshape((len(points), -1)) + properties = np.asarray(t.properties, dtype=f4_dtype) # TRK's streamlines need to be in 'voxelmm' space points = points * self.header[Field.VOXEL_SIZES] @@ -377,8 +377,8 @@ def load(fileobj, ref=None, lazy_load=False): Returns ------- - streamlines : Tractogram object - Returns an object containing streamlines' data and header + tractogram : Tractogram object + Returns an object containing tractogram data and header information. See `nibabel.Tractogram`.
Notes @@ -412,13 +412,13 @@ def _apply_transform(trk_reader): yield pts, scals, props data = lambda: _apply_transform(trk_reader) - streamlines = LazyTractogram.create_from_data(data) + tractogram = LazyTractogram.create_from_data(data) # Overwrite scalars and properties if there are none if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: - streamlines.scalars = lambda: [] + tractogram.scalars = lambda: [] if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: - streamlines.properties = lambda: [] + tractogram.properties = lambda: [] # elif Field.NB_POINTS in trk_reader.header: # # 'count' field is provided, we can avoid creating list of numpy # streamlines.properties = [] else: - streamlines = Tractogram.create_from_generator(trk_reader) - #streamlines = Tractogram(*zip(*trk_reader)) - streamlines.apply_affine(affine) + tractogram = Tractogram.create_from_generator(trk_reader) + #tractogram = Tractogram(*zip(*trk_reader)) + tractogram.apply_affine(affine) # Overwrite scalars and properties if there are none if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: - streamlines.scalars = [] + tractogram.scalars = [] if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: - streamlines.properties = [] + tractogram.properties = [] # Set available common information about streamlines in the header - streamlines.header.to_world_space = affine + tractogram.header.to_world_space = affine # If 'count' field is 0, i.e. not provided, we don't set `nb_streamlines` if trk_reader.header[Field.NB_STREAMLINES] > 0: - streamlines.header.nb_streamlines = trk_reader.header[Field.NB_STREAMLINES] + tractogram.header.nb_streamlines = trk_reader.header[Field.NB_STREAMLINES] # Keep extra information about TRK format - streamlines.header.extra = trk_reader.header + tractogram.header.extra = trk_reader.header ## Perform some integrity checks - #if trk_reader.header[Field.VOXEL_ORDER] != streamlines.header.voxel_order: + #if trk_reader.header[Field.VOXEL_ORDER] != tractogram.header.voxel_order: # raise HeaderError("'voxel_order' does not match the affine.") - #if streamlines.header.voxel_sizes != trk_reader.header[Field.VOXEL_SIZES]: + #if tractogram.header.voxel_sizes != trk_reader.header[Field.VOXEL_SIZES]: # raise HeaderError("'voxel_sizes' does not match the affine.") - #if streamlines.header.nb_scalars_per_point != trk_reader.header[Field.NB_SCALARS_PER_POINT]: + #if tractogram.header.nb_scalars_per_point != trk_reader.header[Field.NB_SCALARS_PER_POINT]: # raise HeaderError("'nb_scalars_per_point' does not match.") - #if streamlines.header.nb_properties_per_streamline != trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: + #if tractogram.header.nb_properties_per_streamline != trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: # raise HeaderError("'nb_properties_per_streamline' does not match.") - return streamlines + return tractogram @staticmethod - def save(streamlines, fileobj, ref=None): - ''' Saves streamlines to a file-like object. + def save(tractogram, fileobj, ref=None): + ''' Saves a tractogram to a file-like object. Parameters ---------- - streamlines : Tractogram object - Object containing streamlines' data and header information. + tractogram : Tractogram object + Object containing tractogram data and header information. See 'nibabel.Tractogram'. fileobj : string or file-like object @@ -520,10 +520,10 @@ def save(streamlines, fileobj, ref=None): refers to the center of the voxel.
''' if ref is not None: - streamlines.header.to_world_space = get_affine_from_reference(ref) + tractogram.header.to_world_space = get_affine_from_reference(ref) - trk_writer = TrkWriter(fileobj, streamlines.header) - trk_writer.write(streamlines) + trk_writer = TrkWriter(fileobj, tractogram.header) + trk_writer.write(tractogram) @staticmethod def pretty_print(fileobj): diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py index 3de3809477..935ae4ad00 100644 --- a/nibabel/testing/__init__.py +++ b/nibabel/testing/__init__.py @@ -59,15 +59,25 @@ def assert_allclose_safely(a, b, match_nans=True): assert_true(np.allclose(a, b)) +def isiterable(iterable): + try: + for _ in iterable: + pass + except Exception: + return False + + return True + + def assert_arrays_equal(arrays1, arrays2): for arr1, arr2 in zip_longest(arrays1, arrays2, fillvalue=None): assert_array_equal(arr1, arr2) -def assert_streamlines_equal(s1, s2): +def assert_tractogram_equal(s1, s2): assert_equal(s1.header, s2.header) assert_equal(len(s1), len(s2)) - assert_arrays_equal(s1.points, s2.points) + assert_arrays_equal(s1.streamlines, s2.streamlines) assert_arrays_equal(s1.scalars, s2.scalars) assert_arrays_equal(s1.properties, s2.properties) From 9fc0f8605c3b6c9f7ee7f32372badbc7fcffb46b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 2 Nov 2015 17:09:19 -0500 Subject: [PATCH 16/54] Refactoring StreamingFile --- nibabel/streamlines/__init__.py | 88 ++++++------- nibabel/streamlines/base_format.py | 66 ++++------ nibabel/streamlines/tests/test_streamlines.py | 66 +++++----- nibabel/streamlines/tests/test_trk.py | 117 ++++++++++-------- nibabel/streamlines/trk.py | 92 +++++++++----- 5 files changed, 224 insertions(+), 205 deletions(-) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index e23757013f..a58f72901c 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -1,4 +1,4 @@ -from .header import Field +from .header import TractogramHeader from .base_format import Tractogram, LazyTractogram from nibabel.streamlines.trk import TrkFile @@ -36,14 +36,14 @@ def detect_format(fileobj): ---------- fileobj : string or file-like object If string, a filename; otherwise an open file-like object pointing - to a streamlines file (and ready to read from the beginning of the + to a tractogram file (and ready to read from the beginning of the header) Returns ------- - streamlines_file : StreamlinesFile object - Object that can be used to manage a streamlines file. - See 'nibabel.streamlines.StreamlinesFile'. + tractogram_file : ``TractogramFile`` subclass + Returns the ``TractogramFile`` subclass matching the format of + `fileobj`, or `None` if the format could not be detected. ''' for format in FORMATS.values(): try: @@ -62,76 +62,68 @@ return None -def load(fileobj, ref, lazy_load=False): +def load(fileobj, lazy_load=False, ref=None): ''' Loads streamlines from a file-like object in voxel space. Parameters ---------- fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to a streamlines file (and ready to read from the beginning - of the streamlines file's header) - - ref : filename | `Nifti1Image` object | 2D array (4,4) - Reference space where streamlines will live in `fileobj`. + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the beginning + of the streamlines file's header).
lazy_load : boolean (optional) Load streamlines in a lazy manner i.e. they will not be kept in memory. + ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) + Reference space where streamlines will live in `fileobj`. + Returns ------- - obj : instance of `Streamlines` - Returns an instance of a `Streamlines` class containing data and metadata - of streamlines loaded from `fileobj`. + tractogram_file : ``TractogramFile`` + Returns an instance of a `TractogramFile` class containing data and + metadata of the tractogram loaded from `fileobj`. ''' - streamlines_file = detect_format(fileobj) + tractogram_file = detect_format(fileobj) - if streamlines_file is None: - raise TypeError("Unknown format for 'fileobj': {0}!".format(fileobj)) + if tractogram_file is None: + raise TypeError("Unknown format for 'fileobj': {}".format(fileobj)) - return streamlines_file.load(fileobj, ref, lazy_load=lazy_load) + return tractogram_file.load(fileobj, lazy_load=lazy_load, ref=ref) -def save(streamlines, filename, ref=None): - ''' Saves a `Streamlines` object to a file +def save(tractogram_file, filename): + ''' Saves a tractogram to a file. Parameters ---------- - streamlines : `Streamlines` object - Streamlines to be saved. + tractogram_file : ``TractogramFile`` object + Tractogram to be saved on disk. filename : str - Name of the file where the streamlines will be saved. The format will - be guessed from `filename`. - - ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) - Reference space where streamlines will live in `fileobj`. + Name of the file where the tractogram will be saved. The format will + be guessed from `filename`. ''' - streamlines_file = detect_format(filename) + tractogram_file.save(filename) - if streamlines_file is None: - raise TypeError("Unknown streamlines file format: '{0}'!".format(filename)) - streamlines_file.save(streamlines, filename, ref) - - -def convert(in_fileobj, out_filename, ref): - ''' Converts a streamlines file to another format. +def save_tractogram(tractogram, filename, **kwargs): + ''' Saves a tractogram to a file. Parameters ---------- - in_fileobj : string or file-like object - If string, a filename; otherwise an open file-like object pointing - to a streamlines file (and ready to read from the beginning of the - header). + tractogram : ``Tractogram`` object + Tractogram to be saved. - out_filename : str - Name of the file where the streamlines will be saved. The format will - be guessed from `out_filename`. - - ref : filename | `Nifti1Image` object | 2D array (4,4) - Reference space where streamlines live in `fileobj`. + filename : str + Name of the file where the tractogram will be saved. The format will + be guessed from `filename`. 
''' - streamlines = load(in_fileobj, ref, lazy_load=True) - save(streamlines, out_filename) + tractogram_file_class = detect_format(filename) + + if tractogram_file_class is None: + raise TypeError("Unknown tractogram file format: '{}'".format(filename)) + + tractogram_file = tractogram_file_class(tractogram, **kwargs) + tractogram_file.save(filename) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 42e34d0a37..4077685dc0 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -339,7 +339,6 @@ def create_from_generator(cls, gen): return cls(streamlines, scalars, properties) - @property def header(self): return self._header @@ -586,12 +585,25 @@ def transform(self, affine): return super(LazyTractogram, self).transform(affine, lazy=True) +class abstractclassmethod(classmethod): + __isabstractmethod__ = True + + def __init__(self, callable): + callable.__isabstractmethod__ = True + super(abstractclassmethod, self).__init__(callable) + + class TractogramFile(object): - ''' Convenience class to encapsulate streamlines file format. ''' + ''' Convenience class to encapsulate tractogram file format. ''' __metaclass__ = ABCMeta - def __init__(self, tractogram): - self.tractogram = tractogram + def __init__(self, tractogram, header=None): + self._tractogram = tractogram + self._header = TractogramHeader() if header is None else header + + @property + def tractogram(self): + return self._tractogram @property def streamlines(self): @@ -607,7 +619,7 @@ def properties(self): @property def header(self): - return self.tractogram.header + return self._header @classmethod def get_magic_number(cls): @@ -642,8 +654,8 @@ def is_correct_format(cls, fileobj): ''' raise NotImplementedError() - @staticmethod - def load(fileobj, ref, lazy_load=True): + @abstractclassmethod + def load(cls, fileobj, lazy_load=True): ''' Loads streamlines from a file-like object. Parameters @@ -653,54 +665,26 @@ def load(fileobj, ref, lazy_load=True): pointing to a streamlines file (and ready to read from the beginning of the header). - ref : filename | `Nifti1Image` object | 2D array (4,4) - Reference space where streamlines live in `fileobj`. - - lazy_load : boolean + lazy_load : boolean (optional) Load streamlines in a lazy manner i.e. they will not be kept in memory. For postprocessing speed, turn off this option. Returns ------- - streamlines : Tractogram object - Returns an object containing streamlines' data and header - information. See 'nibabel.Tractogram'. + tractogram_file : ``TractogramFile`` object + Returns an object containing tractogram data and header + information. ''' raise NotImplementedError() - @staticmethod - def save(streamlines, fileobj, ref=None): + @abstractmethod + def save(self, fileobj): ''' Saves streamlines to a file-like object. Parameters ---------- - streamlines : Tractogram object - Object containing streamlines' data and header information. - See 'nibabel.Tractogram'. - fileobj : string or file-like object If string, a filename; otherwise an open file-like object opened and ready to write. - - ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) - Reference space where streamlines will live in `fileobj`. - ''' - raise NotImplementedError() - - @staticmethod - def pretty_print(streamlines): - ''' Gets a formatted string of the header of a streamlines file format. 
- - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to a streamlines file (and ready to read from the - beginning of the header). - - Returns - ------- - info : string - Header information relevant to the streamlines file format. ''' raise NotImplementedError() diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 5cd30c65f6..e78be90429 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -12,7 +12,7 @@ from nibabel.testing import assert_arrays_equal, isiterable from nose.tools import assert_equal, assert_raises, assert_true, assert_false -from ..base_format import Tractogram, LazyTractogram +from ..base_format import Tractogram, LazyTractogram, TractogramFile from ..base_format import HeaderError, UsageWarning from ..header import Field from .. import trk @@ -134,28 +134,32 @@ def setUp(self): def test_load_empty_file(self): for empty_filename in self.empty_filenames: - tractogram = nib.streamlines.load(empty_filename, - ref=self.to_world_space, - lazy_load=False) - assert_true(type(tractogram), Tractogram) - check_tractogram(tractogram, 0, [], [], []) + tractogram_file = nib.streamlines.load(empty_filename, + lazy_load=False, + ref=self.to_world_space) + assert_true(isinstance(tractogram_file, TractogramFile)) + assert_true(type(tractogram_file.tractogram), Tractogram) + check_tractogram(tractogram_file.tractogram, 0, [], [], []) def test_load_simple_file(self): for simple_filename in self.simple_filenames: - tractogram = nib.streamlines.load(simple_filename, - ref=self.to_world_space, - lazy_load=False) - assert_true(type(tractogram), Tractogram) - check_tractogram(tractogram, self.nb_streamlines, - self.streamlines, [], []) + tractogram_file = nib.streamlines.load(simple_filename, + lazy_load=False, + ref=self.to_world_space) + assert_true(isinstance(tractogram_file, TractogramFile)) + assert_true(type(tractogram_file.tractogram), Tractogram) + check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + self.streamlines, [], []) # Test lazy_load - tractogram = nib.streamlines.load(simple_filename, - ref=self.to_world_space, - lazy_load=True) - assert_true(type(tractogram), LazyTractogram) - check_tractogram(tractogram, self.nb_streamlines, - self.streamlines, [], []) + tractogram_file = nib.streamlines.load(simple_filename, + lazy_load=True, + ref=self.to_world_space) + + assert_true(isinstance(tractogram_file, TractogramFile)) + assert_true(type(tractogram_file.tractogram), LazyTractogram) + check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + self.streamlines, [], []) def test_load_complex_file(self): for complex_filename in self.complex_filenames: @@ -169,20 +173,22 @@ def test_load_complex_file(self): if file_format.can_save_properties(): properties = self.mean_curvature_torsion - tractogram = nib.streamlines.load(complex_filename, - ref=self.to_world_space, - lazy_load=False) - assert_true(type(tractogram), Tractogram) - check_tractogram(tractogram, self.nb_streamlines, - self.streamlines, scalars, properties) + tractogram_file = nib.streamlines.load(complex_filename, + lazy_load=False, + ref=self.to_world_space) + assert_true(isinstance(tractogram_file, TractogramFile)) + assert_true(type(tractogram_file.tractogram), Tractogram) + check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + self.streamlines, scalars, properties) # Test lazy_load - 
tractogram = nib.streamlines.load(complex_filename, - ref=self.to_world_space, - lazy_load=True) - assert_true(type(tractogram), LazyTractogram) - check_tractogram(tractogram, self.nb_streamlines, - self.streamlines, scalars, properties) + tractogram_file = nib.streamlines.load(complex_filename, + lazy_load=True, + ref=self.to_world_space) + assert_true(isinstance(tractogram_file, TractogramFile)) + assert_true(type(tractogram_file.tractogram), LazyTractogram) + check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + self.streamlines, scalars, properties) def test_save_simple_file(self): tractogram = Tractogram(self.streamlines) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 40720027fc..e1584731db 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -12,7 +12,7 @@ from ..base_format import Tractogram, LazyTractogram from ..base_format import DataError, HeaderError, HeaderWarning, UsageWarning -from .. import trk +#from .. import trk from ..trk import TrkFile DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') @@ -57,30 +57,31 @@ def setUp(self): self.nb_streamlines = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) + self.affine = np.eye(4) def test_load_empty_file(self): trk = TrkFile.load(self.empty_trk_filename, ref=None, lazy_load=False) - check_tractogram(trk, 0, [], [], []) + check_tractogram(trk.tractogram, 0, [], [], []) trk = TrkFile.load(self.empty_trk_filename, ref=None, lazy_load=True) # Suppress warning about loading a TRK file in lazy mode with count=0. with suppress_warnings(): - check_tractogram(trk, 0, [], [], []) + check_tractogram(trk.tractogram, 0, [], [], []) def test_load_simple_file(self): trk = TrkFile.load(self.simple_trk_filename, ref=None, lazy_load=False) - check_tractogram(trk, self.nb_streamlines, self.streamlines, [], []) + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) trk = TrkFile.load(self.simple_trk_filename, ref=None, lazy_load=True) - check_tractogram(trk, self.nb_streamlines, self.streamlines, [], []) + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) def test_load_complex_file(self): trk = TrkFile.load(self.complex_trk_filename, ref=None, lazy_load=False) - check_tractogram(trk, self.nb_streamlines, + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, self.colors, self.mean_curvature_torsion) trk = TrkFile.load(self.complex_trk_filename, ref=None, lazy_load=True) - check_tractogram(trk, self.nb_streamlines, + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, self.colors, self.mean_curvature_torsion) def test_load_file_with_wrong_information(self): @@ -89,12 +90,12 @@ def test_load_file_with_wrong_information(self): # Simulate a TRK file where `count` was not provided. 
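The splice below works because the TRK header is a fixed 1000-byte block whose last three fields are, per the TrackVis specification, `n_count` (int32 at offset 988, i.e. 1000-12), `version` (offset 992, i.e. 1000-8) and `hdr_size` (offset 996); replacing bytes 988-991 therefore zeroes the streamline count without touching anything else. A standalone sketch of the same byte surgery:

    import numpy as np

    # Fake 1000-byte TRK header with n_count = 123 at offset 988.
    hdr = bytearray(1000)
    hdr[988:992] = np.array(123, dtype="int32").tobytes()

    # Zero the count, leaving version and hdr_size untouched.
    zero = np.array(0, dtype="int32").tobytes()  # .tostring() in this patch's era
    patched = bytes(hdr[:1000 - 12]) + zero + bytes(hdr[1000 - 8:])

    assert len(patched) == 1000
    assert np.frombuffer(patched[988:992], dtype="int32")[0] == 0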
count = np.array(0, dtype="int32").tostring() new_trk_file = trk_file[:1000-12] + count + trk_file[1000-8:] - tractogram = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) - check_tractogram(tractogram, self.nb_streamlines, self.streamlines, [], []) + trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) - tractogram = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) + trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) with clear_and_catch_warnings(record=True, modules=[base_format]) as w: - check_tractogram(tractogram, self.nb_streamlines, self.streamlines, [], []) + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) assert_equal(len(w), 1) assert_true(issubclass(w[0].category, UsageWarning)) @@ -121,15 +122,16 @@ def test_write_simple_file(self): tractogram = Tractogram(self.streamlines) trk_file = BytesIO() - TrkFile.save(tractogram, trk_file) + trk = TrkFile(tractogram, ref=self.affine) + trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_tractogram = TrkFile.load(trk_file) - check_tractogram(loaded_tractogram, self.nb_streamlines, - self.streamlines, [], []) + loaded_trk = TrkFile.load(trk_file) + check_tractogram(loaded_trk.tractogram, self.nb_streamlines, + self.streamlines, [], []) - loaded_tractogram_orig = TrkFile.load(self.simple_trk_filename) - assert_tractogram_equal(loaded_tractogram, loaded_tractogram_orig) + loaded_trk_orig = TrkFile.load(self.simple_trk_filename) + assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(open(self.simple_trk_filename, 'rb').read(), trk_file.read()) @@ -139,38 +141,40 @@ def test_write_complex_file(self): tractogram = Tractogram(self.streamlines, scalars=self.colors) trk_file = BytesIO() - TrkFile.save(tractogram, trk_file) + trk = TrkFile(tractogram, ref=self.affine) + trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_tractogram = TrkFile.load(trk_file, ref=None, lazy_load=False) - - check_tractogram(loaded_tractogram, self.nb_streamlines, - self.streamlines, self.colors, []) + loaded_trk = TrkFile.load(trk_file, lazy_load=False) + check_tractogram(loaded_trk.tractogram, self.nb_streamlines, + self.streamlines, self.colors, []) # With properties tractogram = Tractogram(self.streamlines, properties=self.mean_curvature_torsion) trk_file = BytesIO() - TrkFile.save(tractogram, trk_file) + trk = TrkFile(tractogram, ref=self.affine) + trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_tractogram = TrkFile.load(trk_file, ref=None, lazy_load=False) - check_tractogram(loaded_tractogram, self.nb_streamlines, - self.streamlines, [], self.mean_curvature_torsion) + loaded_trk = TrkFile.load(trk_file, lazy_load=False) + check_tractogram(loaded_trk.tractogram, self.nb_streamlines, + self.streamlines, [], self.mean_curvature_torsion) # With scalars and properties tractogram = Tractogram(self.streamlines, scalars=self.colors, properties=self.mean_curvature_torsion) trk_file = BytesIO() - TrkFile.save(tractogram, trk_file) + trk = TrkFile(tractogram, ref=self.affine) + trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_tractogram = TrkFile.load(trk_file, ref=None, lazy_load=False) - check_tractogram(loaded_tractogram, self.nb_streamlines, - self.streamlines, self.colors, self.mean_curvature_torsion) + loaded_trk = TrkFile.load(trk_file, lazy_load=False) + check_tractogram(loaded_trk.tractogram, self.nb_streamlines, + self.streamlines, 
self.colors, self.mean_curvature_torsion) - loaded_tractogram_orig = TrkFile.load(self.complex_trk_filename) - assert_tractogram_equal(loaded_tractogram, loaded_tractogram_orig) + loaded_trk_orig = TrkFile.load(self.complex_trk_filename) + assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(open(self.complex_trk_filename, 'rb').read(), trk_file.read()) @@ -182,14 +186,16 @@ def test_write_erroneous_file(self): [(0, 0, 1)]] tractogram = Tractogram(self.streamlines, scalars) - assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) + trk = TrkFile(tractogram, ref=self.affine) + assert_raises(DataError, trk.save, BytesIO()) # No scalars for every streamlines scalars = [[(1, 0, 0)]*1, [(0, 1, 0)]*2] tractogram = Tractogram(self.streamlines, scalars) - assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) + trk = TrkFile(tractogram, ref=self.affine) + assert_raises(DataError, trk.save, BytesIO()) # # Unit test moved to test_base_format.py # # Inconsistent number of scalars between points @@ -215,29 +221,32 @@ def test_write_erroneous_file(self): np.array([2.11], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] tractogram = Tractogram(self.streamlines, properties=properties) - assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) + trk = TrkFile(tractogram, ref=self.affine) + assert_raises(DataError, trk.save, BytesIO()) # Unit test moved to test_base_format.py # No properties for every streamlines properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11, 2.22], dtype="f4")] tractogram = Tractogram(self.streamlines, properties=properties) - assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) - - def test_write_file_lazy_tractogram(self): - streamlines = lambda: (point for point in self.streamlines) - scalars = lambda: (scalar for scalar in self.colors) - properties = lambda: (prop for prop in self.mean_curvature_torsion) - - tractogram = LazyTractogram(streamlines, scalars, properties) - # No need to manually set `nb_streamlines` in the header since we count - # them as writing. - #tractogram.header.nb_streamlines = self.nb_streamlines - - trk_file = BytesIO() - TrkFile.save(tractogram, trk_file) - trk_file.seek(0, os.SEEK_SET) - - trk = TrkFile.load(trk_file, ref=None, lazy_load=False) - check_tractogram(trk, self.nb_streamlines, - self.streamlines, self.colors, self.mean_curvature_torsion) + trk = TrkFile(tractogram, ref=self.affine) + assert_raises(DataError, trk.save, BytesIO()) + + # def test_write_file_lazy_tractogram(self): + # streamlines = lambda: (point for point in self.streamlines) + # scalars = lambda: (scalar for scalar in self.colors) + # properties = lambda: (prop for prop in self.mean_curvature_torsion) + + # tractogram = LazyTractogram(streamlines, scalars, properties) + # # No need to manually set `nb_streamlines` in the header since we count + # # them as writing. 
+ # #tractogram.header.nb_streamlines = self.nb_streamlines + + # trk_file = BytesIO() + # trk = TrkFile(tractogram, ref=self.affine) + # trk.save(trk_file) + # trk_file.seek(0, os.SEEK_SET) + + # trk = TrkFile.load(trk_file, ref=None, lazy_load=False) + # check_tractogram(trk.tractogram, self.nb_streamlines, + # self.streamlines, self.colors, self.mean_curvature_torsion) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index ae913ad67f..1042e5f9af 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -8,11 +8,11 @@ import warnings import numpy as np +import nibabel as nib from nibabel.openers import Opener from nibabel.volumeutils import (native_code, swapped_code) -from nibabel.streamlines.base_format import CompactList from nibabel.streamlines.base_format import TractogramFile from nibabel.streamlines.base_format import DataError, HeaderError, HeaderWarning from nibabel.streamlines.base_format import Tractogram, LazyTractogram @@ -319,6 +319,27 @@ class TrkFile(TractogramFile): MAGIC_NUMBER = b"TRACK" HEADER_SIZE = 1000 + def __init__(self, tractogram, ref, header=None): + """ + Parameters + ---------- + tractogram : ``Tractogram`` object + Tractogram that will be contained in this ``TrkFile``. + + ref : filename | `Nifti1Image` object | 2D array (4,4) + Reference space where streamlines live in. + + header : ``TractogramHeader`` file + Metadata associated to this tractogram file. + + Notes + ----- + Streamlines of the tractogram are assumed to be in *RAS+* and *mm* space + where coordinate (0,0,0) refers to the center of the voxel. + """ + super(TrkFile, self).__init__(tractogram, header) + self._affine = get_affine_from_reference(ref) + @classmethod def get_magic_number(cls): ''' Return TRK's magic number. ''' @@ -357,8 +378,8 @@ def is_correct_format(cls, fileobj): return False - @staticmethod - def load(fileobj, ref=None, lazy_load=False): + @classmethod + def load(cls, fileobj, lazy_load=False, ref=None): ''' Loads streamlines from a file-like object. Parameters @@ -368,30 +389,44 @@ def load(fileobj, ref=None, lazy_load=False): pointing to TRK file (and ready to read from the beginning of the TRK header). - ref : filename | `Nifti1Image` object | 2D array (4,4) | None - Reference space where streamlines live in `fileobj`. - lazy_load : boolean (optional) Load streamlines in a lazy manner i.e. they will not be kept in memory. + ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) + Reference space where streamlines live in `fileobj`. + Returns ------- - tractogram : Tractogram object - Returns an object containing tractogram' data and header - information. See `nibabel.Tractogram`. + trk_file : ``TrkFile`` object + Returns an object containing tractogram data and header + information. Notes ----- - Tractogram are assumed to be in voxel space where coordinate (0,0,0) - refers to the center of the voxel. + Streamlines of the returned tractogram are assumed to be in RASmm + space where coordinate (0,0,0) refers to the center of the voxel. ''' trk_reader = TrkReader(fileobj) # TRK's streamlines are in 'voxelmm' space, we send them to rasmm. - affine = trk_reader.header[Field.to_world_space] + # First send them to voxel space. + affine = np.eye(4) affine[range(3), range(3)] /= trk_reader.header[Field.VOXEL_SIZES] + # If voxel order implied from the affine does not match the voxel + # order save in the TRK header, change the orientation. 
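The reorientation logic that follows leans on four nibabel helpers: `aff2axcodes` reads the axis codes implied by an affine, `axcodes2ornt` turns axis codes into an orientation array, `ornt_transform` computes the reordering and flipping between two orientations, and `inv_ornt_aff` converts that into a 4x4 affine for a volume of a given shape. A self-contained sketch of the same dance, assuming a toy LPS header and an identity (hence RAS) affine:

    import numpy as np
    import nibabel as nib

    affine = np.eye(4)                     # axis codes ('R', 'A', 'S')
    header_ornt = nib.orientations.axcodes2ornt("LPS")
    affine_ornt = nib.orientations.axcodes2ornt(
        nib.orientations.aff2axcodes(affine))
    ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt)
    # 4x4 affine mapping voxel indices stored in LPS order to the RAS
    # order the affine expects, for a toy volume of shape (10, 10, 10).
    M = nib.orientations.inv_ornt_aff(ornt, (10, 10, 10))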
+ header_ornt = trk_reader.header[Field.VOXEL_ORDER] + affine_ornt = "".join(nib.orientations.aff2axcodes(trk_reader.header[Field.to_world_space])) + header_ornt = nib.orientations.axcodes2ornt(header_ornt) + affine_ornt = nib.orientations.axcodes2ornt(affine_ornt) + ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt) + M = nib.orientations.inv_ornt_aff(ornt, trk_reader.header[Field.DIMENSIONS]) + affine = np.dot(M, affine) + + # Applied the affine going from voxel space to rasmm. + affine = np.dot(trk_reader.header[Field.to_world_space], affine) + # TrackVis considers coordinate (0,0,0) to be the corner of the # voxel whereas streamlines returned assume (0,0,0) to be the # center of the voxel. Thus, streamlines are shifted of half @@ -494,36 +529,29 @@ def _apply_transform(trk_reader): #if tractogram.header.nb_properties_per_streamline != trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: # raise HeaderError("'nb_properties_per_streamline' does not match.") - return tractogram + return cls(tractogram, ref=affine, header=trk_reader.header) - @staticmethod - def save(tractogram, fileobj, ref=None): - ''' Saves tractogram to a file-like object. + def save(self, fileobj): + ''' Saves tractogram to a file-like object using TRK format. Parameters ---------- - tractogram : Tractogram object - Object containing tractogram' data and header information. - See 'nibabel.Tractogram'. - fileobj : string or file-like object If string, a filename; otherwise an open file-like object pointing to TRK file (and ready to read from the beginning of the TRK header data). - - ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) - Reference space where streamlines will live in `fileobj`. - - Notes - ----- - Tractogram are assumed to be in voxel space where coordinate (0,0,0) - refers to the center of the voxel. ''' - if ref is not None: - tractogram.header.to_world_space = get_affine_from_reference(ref) + # Update header using the tractogram. + self.header.nb_scalars_per_point = 0 + if self.tractogram.scalars.shape is not None: + self.header.nb_scalars_per_point = len(self.tractogram.scalars[0]) + + self.header.nb_properties_per_streamline = 0 + if self.tractogram.properties.shape is not None: + self.header.nb_properties_per_streamline = len(self.tractogram.properties[0]) - trk_writer = TrkWriter(fileobj, tractogram.header) - trk_writer.write(tractogram) + trk_writer = TrkWriter(fileobj, self.header) + trk_writer.write(self.tractogram) @staticmethod def pretty_print(fileobj): From 9a66d04660ea38d9afaef9395e65727af42ee3c7 Mon Sep 17 00:00:00 2001 From: Eleftherios Garyfallidis Date: Mon, 2 Nov 2015 17:09:54 -0500 Subject: [PATCH 17/54] NF: Tractogram is now more specific and supports multiple keys for scalars and properties which are now called data_per_steamlines and data_per_points respectively --- nibabel/streamlines/base_format.py | 226 +++++------------- nibabel/streamlines/tests/test_base_format.py | 15 ++ 2 files changed, 70 insertions(+), 171 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 42e34d0a37..bf9cea59a7 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -194,10 +194,10 @@ def __repr__(self): class TractogramItem(object): - ''' Class containing information about one streamline. + """ Class containing information about one streamline. - ``TractogramItem`` objects have three main properties: `streamline`, `scalars` - and ``properties``. 
+ ``TractogramItem`` objects have three main properties: `streamline`, + `data_for_streamline`, and `data_for_points`. Parameters ---------- @@ -205,22 +205,16 @@ class TractogramItem(object): Points of this streamline represented as an ndarray of shape (N, 3) where N is the number of points. - scalars : ndarray of shape (N, M) - Scalars associated with each point of this streamline and represented - as an ndarray of shape (N, M) where N is the number of points and - M is the number of scalars (excluding the three coordinates). + data_for_streamline : dict - properties : ndarray of shape (P,) - Properties associated with this streamline and represented as an - ndarray of shape (P,) where P is the number of properties. - ''' - def __init__(self, streamline, scalars=None, properties=None): - #if scalars is not None and len(streamline) != len(scalars): - # raise ValueError("First dimension of streamline and scalars must match.") + data_for_points : dict + """ + def __init__(self, streamline, + data_for_streamline, data_for_points): self.streamline = np.asarray(streamline) - self.scalars = np.asarray([] if scalars is None else scalars) - self.properties = np.asarray([] if properties is None else properties) + self.data_for_streamline = data_for_streamline + self.data_for_points = data_for_points def __iter__(self): return iter(self.streamline) @@ -230,187 +224,77 @@ def __len__(self): class Tractogram(object): - ''' Class containing information about streamlines. + """ Class containing information about streamlines. - Tractogram objects have three main properties: ``streamlines``, ``scalars`` - and ``properties``. Tractogram objects can be iterate over producing - tuple of ``streamlines``, ``scalars`` and ``properties`` for each streamline. + Tractogram objects have three main properties: ``streamlines`` Parameters ---------- streamlines : list of ndarray of shape (Nt, 3) - Sequence of T streamlines. One streamline is an ndarray of shape (Nt, 3) - where Nt is the number of points of streamline t. + Sequence of T streamlines. One streamline is an ndarray of shape + (Nt, 3) where Nt is the number of points of streamline t. + + data_per_streamline : dictionary of list of ndarray of shape (P,) + Sequence of T ndarrays of shape (P,) where T is the number of + streamlines defined by ``streamlines``, P is the number of properties + associated to each streamline. - scalars : list of ndarray of shape (Nt, M) + data_per_point : dictionary of list of ndarray of shape (Nt, M) Sequence of T ndarrays of shape (Nt, M) where T is the number of streamlines defined by ``streamlines``, Nt is the number of points for a particular streamline t and M is the number of scalars associated to each point (excluding the three coordinates). - properties : list of ndarray of shape (P,) - Sequence of T ndarrays of shape (P,) where T is the number of - streamlines defined by ``streamlines``, P is the number of properties - associated to each streamline. - ''' - def __init__(self, streamlines=None, scalars=None, properties=None): - self._header = TractogramHeader() - self.streamlines = streamlines - self.scalars = scalars - self.properties = properties + """ + def __init__(self, streamlines=None, + data_per_streamline=None, + data_per_point=None): - @classmethod - def create_from_generator(cls, gen): - BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. 
- - streamlines = CompactList() - scalars = CompactList() - properties = np.array([]) - - gen = iter(gen) - try: - first_element = next(gen) - gen = itertools.chain([first_element], gen) - except StopIteration: - return cls(streamlines, scalars, properties) - - # Allocated some buffer memory. - pts = np.asarray(first_element[0]) - scals = np.asarray(first_element[1]) - props = np.asarray(first_element[2]) - - scals_shape = scals.shape - props_shape = props.shape - - streamlines._data = np.empty((BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype) - scalars._data = np.empty((BUFFER_SIZE, scals.shape[1]), dtype=scals.dtype) - properties = np.empty((BUFFER_SIZE, props.shape[0]), dtype=props.dtype) - - offset = 0 - for i, (pts, scals, props) in enumerate(gen): - pts = np.asarray(pts) - scals = np.asarray(scals) - props = np.asarray(props) - - if scals.shape[1] != scals_shape[1]: - raise ValueError("Number of scalars differs from one" - " point or streamline to another") - - if props.shape != props_shape: - raise ValueError("Number of properties differs from one" - " streamline to another") - - end = offset + len(pts) - if end >= len(streamlines._data): - # Resize is needed (at least `len(pts)` items will be added). - streamlines._data.resize((len(streamlines._data) + len(pts)+BUFFER_SIZE, pts.shape[1])) - scalars._data.resize((len(scalars._data) + len(scals)+BUFFER_SIZE, scals.shape[1])) - - streamlines._offsets.append(offset) - streamlines._lengths.append(len(pts)) - streamlines._data[offset:offset+len(pts)] = pts - scalars._data[offset:offset+len(scals)] = scals - - offset += len(pts) - - if i >= len(properties): - properties.resize((len(properties) + BUFFER_SIZE, props.shape[0])) - - properties[i] = props - - # Clear unused memory. - streamlines._data.resize((offset, pts.shape[1])) - - if scals_shape[1] == 0: - # Because resizing an empty ndarray creates memory! - scalars._data = np.empty((offset, scals.shape[1])) + self.streamlines = streamlines + if data_per_streamline is None: + self.data_per_streamline = {} else: - scalars._data.resize((offset, scals.shape[1])) - - # Share offsets and lengths between streamlines and scalars. - scalars._offsets = streamlines._offsets - scalars._lengths = streamlines._lengths - - if props_shape[0] == 0: - # Because resizing an empty ndarray creates memory! 
- properties = np.empty((i+1, props.shape[0])) + self.data_per_streamline = data_per_streamline + if data_per_point is None: + self.data_per_point = {} else: - properties.resize((i+1, props.shape[0])) - - return cls(streamlines, scalars, properties) - - - @property - def header(self): - return self._header - - @property - def streamlines(self): - return self._streamlines - - @streamlines.setter - def streamlines(self, value): - self._streamlines = value - if not isinstance(value, CompactList): - self._streamlines = CompactList(value) - - self.header.nb_streamlines = len(self.streamlines) - - @property - def scalars(self): - return self._scalars - - @scalars.setter - def scalars(self, value): - self._scalars = value - if not isinstance(value, CompactList): - self._scalars = CompactList(value) - - self.header.nb_scalars_per_point = 0 - if len(self.scalars) > 0 and len(self.scalars[0]) > 0: - self.header.nb_scalars_per_point = len(self.scalars[0][0]) - - @property - def properties(self): - return self._properties - - @properties.setter - def properties(self, value): - self._properties = np.asarray(value) - if value is None: - self._properties = np.empty((len(self), 0), dtype=np.float32) - - self.header.nb_properties_per_streamline = 0 - if len(self.properties) > 0: - self.header.nb_properties_per_streamline = len(self.properties[0]) + self.data_per_point = data_per_point def __iter__(self): - for data in zip_longest(self.streamlines, self.scalars, self.properties, fillvalue=None): - yield TractogramItem(*data) + for i in range(len(self.streamlines)): + yield self[i] def __getitem__(self, idx): pts = self.streamlines[idx] - scalars = [] - if len(self.scalars) > 0: - scalars = self.scalars[idx] - properties = [] - if len(self.properties) > 0: - properties = self.properties[idx] + new_data_per_streamline = {} + for key in self.data_per_streamline: + new_data_per_streamline[key] = self.data_per_streamline[key][idx] + + new_data_per_point = {} + for key in self.data_per_point: + new_data_per_point[key] = self.data_per_point[key][idx] if type(idx) is slice: - return Tractogram(pts, scalars, properties) + return Tractogram(pts, new_data_per_streamline, new_data_per_point) - return TractogramItem(pts, scalars, properties) - - def __len__(self): - return len(self.streamlines) + return TractogramItem(pts, new_data_per_streamline, new_data_per_point) def copy(self): """ Returns a copy of this `Tractogram` object. """ - streamlines = Tractogram(self.streamlines.copy(), self.scalars.copy(), self.properties.copy()) - streamlines._header = self.header.copy() - return streamlines + + new_data_per_streamline = {} + for key in self.data_per_streamline: + new_data_per_streamline[key] = self.data_per_streamline[key].copy() + + new_data_per_point = {} + for key in self.data_per_point: + new_data_per_point[key] = self.data_per_point[key].copy() + + tractogram = Tractogram(self.streamlines.copy(), + new_data_per_streamline, + new_data_per_point) + return tractogram def apply_affine(self, affine): """ Applies an affine transformation on the points of each streamline. 
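The dict-based API above replaces the old positional `scalars`/`properties` pair: per-streamline arrays live under `data_per_streamline`, per-point arrays under `data_per_point`, and indexing or slicing a `Tractogram` slices every entry of both dicts alongside the streamlines. A short usage sketch (module path as in this series; the toy arrays are ours):

    import numpy as np
    from nibabel.streamlines.base_format import Tractogram

    streamlines = [np.arange(n * 3, dtype="f4").reshape((n, 3))
                   for n in (1, 2, 5)]
    colors = [np.tile([1., 0., 0.], (len(s), 1)) for s in streamlines]
    curvature = np.array([1.11, 2.11, 3.11], dtype="f4")

    tractogram = Tractogram(
        streamlines,
        data_per_streamline={"mean_curvature": curvature},
        data_per_point={"colors": colors})

    item = tractogram[0]        # TractogramItem: one streamline + its data
    first_two = tractogram[:2]  # Tractogram: both dicts are sliced in step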
diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index 99d8c8ea90..ad22274012 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -251,10 +251,25 @@ def setUp(self): np.array([2.11, 2.22], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] + self.mean_curvature = np.array([1.11, 2.11, 3.11], dtype="f4") + self.mean_color = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype="f4") + + self.nb_tractogram = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) + def test_tractogram_creation_with_dix(self): + + tractogram = Tractogram( + streamlines=self.streamlines, + data_per_streamline={'mean_curvature': self.mean_curvature, + 'mean_color': self.mean_color}, + data_per_point={'colors': self.colors}) + + tractogram[:2] + + def test_tractogram_creation_from_arrays(self): # Empty tractogram = Tractogram() From b6e9cf0cec215bf32c035d5e8180c4e4a10e34a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 2 Nov 2015 18:03:37 -0500 Subject: [PATCH 18/54] Fixed some unit tests for Tractogram and TractogramItem --- nibabel/streamlines/base_format.py | 47 +++- nibabel/streamlines/tests/test_base_format.py | 247 +++++++----------- 2 files changed, 133 insertions(+), 161 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index bf9cea59a7..52563a5a97 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -209,9 +209,7 @@ class TractogramItem(object): data_for_points : dict """ - def __init__(self, streamline, - data_for_streamline, data_for_points): - + def __init__(self, streamline, data_for_streamline, data_for_points): self.streamline = np.asarray(streamline) self.data_for_streamline = data_for_streamline self.data_for_points = data_for_points @@ -251,15 +249,43 @@ def __init__(self, streamlines=None, data_per_point=None): self.streamlines = streamlines - if data_per_streamline is None: - self.data_per_streamline = {} - else: + + self.data_per_streamline = {} + if data_per_streamline is not None: self.data_per_streamline = data_per_streamline - if data_per_point is None: - self.data_per_point = {} - else: + + self.data_per_point = {} + if data_per_point is not None: self.data_per_point = data_per_point + @property + def streamlines(self): + return self._streamlines + + @streamlines.setter + def streamlines(self, value): + self._streamlines = CompactList(value) + + @property + def data_per_streamline(self): + return self._data_per_streamline + + @data_per_streamline.setter + def data_per_streamline(self, value): + self._data_per_streamline = {} + for k, v in value.items(): + self._data_per_streamline[k] = np.asarray(v) + + @property + def data_per_point(self): + return self._data_per_point + + @data_per_point.setter + def data_per_point(self, value): + self._data_per_point = {} + for k, v in value.items(): + self._data_per_point[k] = CompactList(v) + def __iter__(self): for i in range(len(self.streamlines)): yield self[i] @@ -280,6 +306,9 @@ def __getitem__(self, idx): return TractogramItem(pts, new_data_per_streamline, new_data_per_point) + def __len__(self): + return len(self.streamlines) + def copy(self): """ Returns a copy of this `Tractogram` object. 
""" diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index ad22274012..5ae8c0d5d8 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -3,7 +3,7 @@ import numpy as np import warnings -from nibabel.testing import assert_arrays_equal +from nibabel.testing import assert_arrays_equal, isiterable from nibabel.testing import suppress_warnings, clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true from numpy.testing import assert_array_equal, assert_array_almost_equal @@ -209,36 +209,31 @@ class TestTractogramItem(unittest.TestCase): def test_creating_tractogram_item(self): rng = np.random.RandomState(42) streamline = rng.rand(rng.randint(10, 50), 3) - scalars = rng.rand(len(streamline), 5) - properties = rng.rand(42) + colors = rng.rand(len(streamline), 3) + mean_curvature = 1.11 + mean_color = np.array([0, 1, 0], dtype="f4") - # Create a streamline with only points - t = TractogramItem(streamline) - assert_equal(len(t), len(streamline)) - assert_array_equal(t.scalars, []) - assert_array_equal(t.properties, []) + data_for_streamline = {"mean_curvature": mean_curvature, + "mean_color": mean_color} + + data_for_points = {"colors": colors} - # Create a streamline with points, scalars and properties. - t = TractogramItem(streamline, scalars, properties) + # Create a tractogram item with a streamline, data. + t = TractogramItem(streamline, data_for_streamline, data_for_points) assert_equal(len(t), len(streamline)) assert_array_equal(t.streamline, streamline) assert_array_equal(list(t), streamline) - assert_equal(len(t), len(scalars)) - assert_array_equal(t.scalars, scalars) - assert_array_equal(t.properties, properties) - - # # Create a streamline with different number of scalars. 
- # scalars = rng.rand(len(streamline)+3, 5) - # assert_raises(ValueError, TractogramItem, streamline, scalars) + assert_array_equal(t.data_for_streamline['mean_curvature'], + mean_curvature) + assert_array_equal(t.data_for_streamline['mean_color'], + mean_color) + assert_array_equal(t.data_for_points['colors'], + colors) class TestTractogram(unittest.TestCase): def setUp(self): - self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") - self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") - self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") - self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), np.arange(2*3, dtype="f4").reshape((2, 3)), np.arange(5*3, dtype="f4").reshape((5, 3))] @@ -247,77 +242,56 @@ def setUp(self): np.array([(0, 1, 0)]*2, dtype="f4"), np.array([(0, 0, 1)]*5, dtype="f4")] - self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), - np.array([2.11, 2.22], dtype="f4"), - np.array([3.11, 3.22], dtype="f4")] - self.mean_curvature = np.array([1.11, 2.11, 3.11], dtype="f4") - self.mean_color = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype="f4") - + self.mean_color = np.array([[0, 1, 0], + [0, 0, 1], + [1, 0, 0]], dtype="f4") self.nb_tractogram = len(self.streamlines) - self.nb_scalars_per_point = self.colors[0].shape[1] - self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) - - def test_tractogram_creation_with_dix(self): - - tractogram = Tractogram( - streamlines=self.streamlines, - data_per_streamline={'mean_curvature': self.mean_curvature, - 'mean_color': self.mean_color}, - data_per_point={'colors': self.colors}) - - tractogram[:2] - - def test_tractogram_creation_from_arrays(self): - # Empty + def test_tractogram_creation(self): + # Create an empty tractogram. tractogram = Tractogram() assert_equal(len(tractogram), 0) assert_arrays_equal(tractogram.streamlines, []) - assert_arrays_equal(tractogram.scalars, []) - assert_arrays_equal(tractogram.properties, []) + assert_equal(tractogram.data_per_streamline, {}) + assert_equal(tractogram.data_per_point, {}) + assert_true(isiterable(tractogram)) - # Check if we can iterate through the tractogram. - for streamline in tractogram: - pass - - # Only streamlines + # Create a tractogram with only streamlines tractogram = Tractogram(streamlines=self.streamlines) assert_equal(len(tractogram), len(self.streamlines)) assert_arrays_equal(tractogram.streamlines, self.streamlines) - assert_arrays_equal(tractogram.scalars, []) - assert_arrays_equal(tractogram.properties, []) + assert_equal(tractogram.data_per_streamline, {}) + assert_equal(tractogram.data_per_point, {}) + assert_true(isiterable(tractogram)) - # Check if we can iterate through the tractogram. - for streamline in tractogram: - pass + # Create a tractogram with streamlines and other data. 
+ tractogram = Tractogram( + self.streamlines, + data_per_streamline={'mean_curvature': self.mean_curvature, + 'mean_color': self.mean_color}, + data_per_point={'colors': self.colors}) - # Points, scalars and properties - tractogram = Tractogram(self.streamlines, self.colors, self.mean_curvature_torsion) assert_equal(len(tractogram), len(self.streamlines)) assert_arrays_equal(tractogram.streamlines, self.streamlines) - assert_arrays_equal(tractogram.scalars, self.colors) - assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) + assert_arrays_equal(tractogram.data_per_streamline['mean_curvature'], + self.mean_curvature) + assert_arrays_equal(tractogram.data_per_streamline['mean_color'], + self.mean_color) + assert_arrays_equal(tractogram.data_per_point['colors'], + self.colors) - # Check if we can iterate through the tractogram. - for streamline in tractogram: - pass + assert_true(isiterable(tractogram)) # Inconsistent number of scalars between streamlines - scalars = [[(1, 0, 0)]*1, - [(0, 1, 0), (0, 1)], - [(0, 0, 1)]*5] - - assert_raises(ValueError, Tractogram, self.streamlines, scalars) - - # Unit test moved to test_base_format.py - # Inconsistent number of scalars between tractogram - scalars = [[(1, 0, 0)]*1, - [(0, 1)]*2, - [(0, 0, 1)]*5] + wrong_data = [[(1, 0, 0)]*1, + [(0, 1, 0), (0, 1)], + [(0, 0, 1)]*5] - assert_raises(ValueError, Tractogram, self.streamlines, scalars) + data_per_point = {'wrong_data': wrong_data} + assert_raises(ValueError, Tractogram, self.streamlines, + data_per_point=data_per_point) def test_tractogram_getter(self): # Tractogram with only streamlines @@ -326,78 +300,71 @@ def test_tractogram_getter(self): selected_tractogram = tractogram[::2] assert_equal(len(selected_tractogram), (len(self.streamlines)+1)//2) - assert_arrays_equal(selected_tractogram.streamlines, self.streamlines[::2]) - assert_equal(sum(map(len, selected_tractogram.scalars)), 0) - assert_equal(sum(map(len, selected_tractogram.properties)), 0) + assert_arrays_equal(selected_tractogram.streamlines, + self.streamlines[::2]) + assert_equal(tractogram.data_per_streamline, {}) + assert_equal(tractogram.data_per_point, {}) - # Tractogram with streamlines, scalars and properties - tractogram = Tractogram(self.streamlines, self.colors, self.mean_curvature_torsion) + # Create a tractogram with streamlines and other data. 
+ tractogram = Tractogram( + self.streamlines, + data_per_streamline={'mean_curvature': self.mean_curvature, + 'mean_color': self.mean_color}, + data_per_point={'colors': self.colors}) # Retrieve tractogram by their index for i, t in enumerate(tractogram): assert_array_equal(t.streamline, tractogram[i].streamline) - assert_array_equal(t.scalars, tractogram[i].scalars) - assert_array_equal(t.properties, tractogram[i].properties) + assert_array_equal(t.data_for_points['colors'], + tractogram[i].data_for_points['colors']) + + assert_array_equal(t.data_for_streamline['mean_curvature'], + tractogram[i].data_for_streamline['mean_curvature']) + + assert_array_equal(t.data_for_streamline['mean_color'], + tractogram[i].data_for_streamline['mean_color']) # Use slicing r_tractogram = tractogram[::-1] assert_arrays_equal(r_tractogram.streamlines, self.streamlines[::-1]) - assert_arrays_equal(r_tractogram.scalars, self.colors[::-1]) - assert_arrays_equal(r_tractogram.properties, self.mean_curvature_torsion[::-1]) - def test_tractogram_creation_from_generator(self): - # Create `Tractogram` from a generator yielding 3-tuples - gen = (x for x in zip(self.streamlines, self.colors, self.mean_curvature_torsion)) - - tractogram = Tractogram.create_from_generator(gen) - with suppress_warnings(): - assert_equal(len(tractogram), self.nb_tractogram) + assert_arrays_equal(r_tractogram.data_per_streamline['mean_curvature'], + self.mean_curvature[::-1]) + assert_arrays_equal(r_tractogram.data_per_streamline['mean_color'], + self.mean_color[::-1]) + assert_arrays_equal(r_tractogram.data_per_point['colors'], + self.colors[::-1]) - assert_arrays_equal(tractogram.streamlines, self.streamlines) - assert_arrays_equal(tractogram.scalars, self.colors) - assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) - - # Check if we can iterate through the tractogram. - for streamline in tractogram: - pass - - def test_tractogram_creation_from_coroutines(self): - # Points, scalars and properties - streamlines = lambda: (x for x in self.streamlines) - scalars = lambda: (x for x in self.colors) - properties = lambda: (x for x in self.mean_curvature_torsion) - - # To create tractogram from multiple coroutines use `LazyTractogram`. 
- assert_raises(TypeError, Tractogram, streamlines, scalars, properties) + def test_tractogram_add_new_data(self): + # Tractogram with only streamlines + tractogram = Tractogram(streamlines=self.streamlines) - def test_header(self): - # Empty Tractogram, with default header - tractogram = Tractogram() - assert_equal(tractogram.header.nb_streamlines, 0) - assert_equal(tractogram.header.nb_scalars_per_point, 0) - assert_equal(tractogram.header.nb_properties_per_streamline, 0) - assert_array_equal(tractogram.header.voxel_sizes, (1, 1, 1)) - assert_array_equal(tractogram.header.to_world_space, np.eye(4)) - assert_equal(tractogram.header.extra, {}) + tractogram.data_per_streamline['mean_curvature'] = self.mean_curvature + tractogram.data_per_streamline['mean_color'] = self.mean_color + tractogram.data_per_point['colors'] = self.colors - tractogram = Tractogram(self.streamlines, self.colors, self.mean_curvature_torsion) + # Retrieve tractogram by their index + for i, t in enumerate(tractogram): + assert_array_equal(t.streamline, tractogram[i].streamline) + assert_array_equal(t.data_for_points['colors'], + tractogram[i].data_for_points['colors']) - assert_equal(tractogram.header.nb_streamlines, len(self.streamlines)) - assert_equal(tractogram.header.nb_scalars_per_point, self.colors[0].shape[1]) - assert_equal(tractogram.header.nb_properties_per_streamline, self.mean_curvature_torsion[0].shape[0]) + assert_array_equal(t.data_for_streamline['mean_curvature'], + tractogram[i].data_for_streamline['mean_curvature']) - # Modifying voxel_sizes should be reflected in to_world_space - tractogram.header.voxel_sizes = (2, 3, 4) - assert_array_equal(tractogram.header.voxel_sizes, (2, 3, 4)) - assert_array_equal(np.diag(tractogram.header.to_world_space), (2, 3, 4, 1)) + assert_array_equal(t.data_for_streamline['mean_color'], + tractogram[i].data_for_streamline['mean_color']) - # Modifying scaling of to_world_space should be reflected in voxel_sizes - tractogram.header.to_world_space = np.diag([4, 3, 2, 1]) - assert_array_equal(tractogram.header.voxel_sizes, (4, 3, 2)) - assert_array_equal(tractogram.header.to_world_space, np.diag([4, 3, 2, 1])) + # Use slicing + r_tractogram = tractogram[::-1] + assert_arrays_equal(r_tractogram.streamlines, self.streamlines[::-1]) - # Test that we can run __repr__ without error. - repr(tractogram.header) + assert_arrays_equal(r_tractogram.data_per_streamline['mean_curvature'], + self.mean_curvature[::-1]) + assert_arrays_equal(r_tractogram.data_per_streamline['mean_color'], + self.mean_color[::-1]) + assert_arrays_equal(r_tractogram.data_per_point['colors'], + self.colors[::-1]) class TestLazyTractogram(unittest.TestCase): @@ -553,27 +520,3 @@ def test_lazy_tractogram_len(self): # This should *not* produce a warning. 
assert_equal(len(tractogram), 1234) assert_equal(len(w), 0) - - def test_lazy_tractogram_header(self): - # Empty `LazyTractogram`, with default header - tractogram = LazyTractogram() - assert_true(tractogram.header.nb_streamlines is None) - assert_equal(tractogram.header.nb_scalars_per_point, 0) - assert_equal(tractogram.header.nb_properties_per_streamline, 0) - assert_array_equal(tractogram.header.voxel_sizes, (1, 1, 1)) - assert_array_equal(tractogram.header.to_world_space, np.eye(4)) - assert_equal(tractogram.header.extra, {}) - - streamlines = lambda: (x for x in self.streamlines) - scalars = lambda: (x for x in self.colors) - properties = lambda: (x for x in self.mean_curvature_torsion) - tractogram = LazyTractogram(streamlines) - header = tractogram.header - - assert_equal(header.nb_scalars_per_point, 0) - tractogram.scalars = scalars - assert_equal(header.nb_scalars_per_point, self.nb_scalars_per_point) - - assert_equal(header.nb_properties_per_streamline, 0) - tractogram.properties = properties - assert_equal(header.nb_properties_per_streamline, self.nb_properties_per_streamline) From 40bf350ff7a05a98e7a2dedc1033d63da673c4c8 Mon Sep 17 00:00:00 2001 From: Eleftherios Garyfallidis Date: Mon, 2 Nov 2015 18:14:31 -0500 Subject: [PATCH 19/54] Updated TractogramFile --- nibabel/streamlines/base_format.py | 93 ++++++++++-------------------- 1 file changed, 30 insertions(+), 63 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 52563a5a97..e1138420e3 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -499,55 +499,54 @@ def transform(self, affine): return super(LazyTractogram, self).transform(affine, lazy=True) +class abstractclassmethod(classmethod): + __isabstractmethod__ = True + + def __init__(self, callable): + callable.__isabstractmethod__ = True + super(abstractclassmethod, self).__init__(callable) + + class TractogramFile(object): - ''' Convenience class to encapsulate streamlines file format. ''' + ''' Convenience class to encapsulate tractogram file format. ''' __metaclass__ = ABCMeta - def __init__(self, tractogram): - self.tractogram = tractogram + def __init__(self, tractogram, header=None): + self._tractogram = tractogram + self._header = TractogramHeader() if header is None else header + + @property + def tractogram(self): + return self._tractogram @property def streamlines(self): return self.tractogram.streamlines @property - def scalars(self): - return self.tractogram.scalars + def header(self): + return self._header - @property - def properties(self): - return self.tractogram.properties + def get_tractogram(self): + return self.tractogram - @property - def header(self): - return self.tractogram.header + def get_header(self): + return self.header @classmethod def get_magic_number(cls): ''' Returns streamlines file's magic number. ''' raise NotImplementedError() - @classmethod - def can_save_scalars(cls): - ''' Tells if the streamlines format supports saving scalars. ''' - raise NotImplementedError() - - @classmethod - def can_save_properties(cls): - ''' Tells if the streamlines format supports saving properties. ''' - raise NotImplementedError() - @classmethod def is_correct_format(cls, fileobj): ''' Checks if the file has the right streamlines file format. - Parameters ---------- fileobj : string or file-like object If string, a filename; otherwise an open file-like object pointing to a streamlines file (and ready to read from the beginning of the header). 
- Returns ------- is_correct_format : boolean @@ -555,65 +554,33 @@ def is_correct_format(cls, fileobj): ''' raise NotImplementedError() - @staticmethod - def load(fileobj, ref, lazy_load=True): + @abstractclassmethod + def load(cls, fileobj, lazy_load=True): ''' Loads streamlines from a file-like object. - Parameters ---------- fileobj : string or file-like object If string, a filename; otherwise an open file-like object pointing to a streamlines file (and ready to read from the beginning of the header). - - ref : filename | `Nifti1Image` object | 2D array (4,4) - Reference space where streamlines live in `fileobj`. - - lazy_load : boolean + lazy_load : boolean (optional) Load streamlines in a lazy manner i.e. they will not be kept in memory. For postprocessing speed, turn off this option. - Returns ------- - streamlines : Tractogram object - Returns an object containing streamlines' data and header - information. See 'nibabel.Tractogram'. + tractogram_file : ``TractogramFile`` object + Returns an object containing tractogram data and header + information. ''' raise NotImplementedError() - @staticmethod - def save(streamlines, fileobj, ref=None): + @abstractmethod + def save(self, fileobj): ''' Saves streamlines to a file-like object. - Parameters ---------- - streamlines : Tractogram object - Object containing streamlines' data and header information. - See 'nibabel.Tractogram'. - fileobj : string or file-like object If string, a filename; otherwise an open file-like object opened and ready to write. - - ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) - Reference space where streamlines will live in `fileobj`. - ''' - raise NotImplementedError() - - @staticmethod - def pretty_print(streamlines): - ''' Gets a formatted string of the header of a streamlines file format. - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to a streamlines file (and ready to read from the - beginning of the header). - - Returns - ------- - info : string - Header information relevant to the streamlines file format. ''' raise NotImplementedError() From c84b70c878f6de31cf43762805044810b6cda69c Mon Sep 17 00:00:00 2001 From: Eleftherios Garyfallidis Date: Mon, 2 Nov 2015 18:17:02 -0500 Subject: [PATCH 20/54] minor cleanup --- nibabel/streamlines/trk.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 1042e5f9af..61862f7a0e 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -345,16 +345,6 @@ def get_magic_number(cls): ''' Return TRK's magic number. ''' return cls.MAGIC_NUMBER - @classmethod - def can_save_scalars(cls): - ''' Tells if the streamlines format supports saving scalars. ''' - return True - - @classmethod - def can_save_properties(cls): - ''' Tells if the streamlines format supports saving properties. ''' - return True - @classmethod def is_correct_format(cls, fileobj): ''' Check if the file is in TRK format. 
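The `abstractclassmethod` shim above is needed because `abc.abstractclassmethod` only appeared in Python 3.2 (and was deprecated in 3.3 in favour of stacking `@classmethod` over `@abstractmethod`): setting `__isabstractmethod__` on the wrapped callable before handing it to `classmethod` is what lets `ABCMeta` refuse to instantiate classes that never override the method. One caveat: `__metaclass__ = ABCMeta` is the Python 2 spelling; under Python 3 the metaclass must be passed in the class statement for the check to fire, as in this sketch (class names are illustrative):

    from abc import ABCMeta

    class abstractclassmethod(classmethod):
        __isabstractmethod__ = True

        def __init__(self, callable):
            callable.__isabstractmethod__ = True
            super(abstractclassmethod, self).__init__(callable)

    class Base(metaclass=ABCMeta):
        @abstractclassmethod
        def load(cls, fileobj):
            raise NotImplementedError()

    class Concrete(Base):
        @classmethod
        def load(cls, fileobj):
            return cls()

    Concrete.load(None)  # fine: the override satisfies the ABC
    # Base() raises TypeError: abstract method `load` was never overridden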
From 19268ff2c352f92970b236c8f933c392aa137257 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 2 Nov 2015 23:06:44 -0500 Subject: [PATCH 21/54] Changed unit tests to reflect modifications made to Tractogram --- nibabel/streamlines/__init__.py | 2 +- nibabel/streamlines/base_format.py | 39 ++- nibabel/streamlines/header.py | 4 +- nibabel/streamlines/tests/test_base_format.py | 9 + nibabel/streamlines/tests/test_streamlines.py | 127 +++++---- nibabel/streamlines/tests/test_trk.py | 142 ++++++---- nibabel/streamlines/trk.py | 264 ++++++++++-------- nibabel/testing/__init__.py | 8 - 8 files changed, 344 insertions(+), 251 deletions(-) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index a58f72901c..cfd8990741 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -90,7 +90,7 @@ def load(fileobj, lazy_load=False, ref=None): if tractogram_file is None: raise TypeError("Unknown format for 'fileobj': {}".format(fileobj)) - return tractogram_file.load(fileobj, lazy_load=lazy_load, ref=ref) + return tractogram_file.load(fileobj, lazy_load=lazy_load) def save(tractogram_file, filename): diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index e1138420e3..9276f01f2c 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -38,13 +38,24 @@ def __init__(self, iterable=None): iterable : iterable (optional) If specified, create a ``CompactList`` object initialized from iterable's items. Otherwise, create an empty ``CompactList``. + + Notes + ----- + If `iterable` is a ``CompactList`` object, a view is returned and no + memory is allocated. For an actual copy use the `.copy()` method. """ # Create new empty `CompactList` object. self._data = None self._offsets = [] self._lengths = [] - if iterable is not None: + if isinstance(iterable, CompactList): + # Create a view. + self._data = iterable._data + self._offsets = iterable._offsets + self._lengths = iterable._lengths + + elif iterable is not None: # Initialize the `CompactList` object from iterable's item. BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. @@ -249,14 +260,8 @@ def __init__(self, streamlines=None, data_per_point=None): self.streamlines = streamlines - - self.data_per_streamline = {} - if data_per_streamline is not None: - self.data_per_streamline = data_per_streamline - - self.data_per_point = {} - if data_per_point is not None: - self.data_per_point = data_per_point + self.data_per_streamline = data_per_streamline + self.data_per_point = data_per_point @property def streamlines(self): @@ -272,6 +277,9 @@ def data_per_streamline(self): @data_per_streamline.setter def data_per_streamline(self, value): + if value is None: + value = {} + self._data_per_streamline = {} for k, v in value.items(): self._data_per_streamline[k] = np.asarray(v) @@ -282,6 +290,9 @@ def data_per_point(self): @data_per_point.setter def data_per_point(self, value): + if value is None: + value = {} + self._data_per_point = {} for k, v in value.items(): self._data_per_point[k] = CompactList(v) @@ -538,6 +549,16 @@ def get_magic_number(cls): ''' Returns streamlines file's magic number. ''' raise NotImplementedError() + @classmethod + def support_data_per_point(cls): + ''' Tells if this tractogram format supports saving data per point. ''' + raise NotImplementedError() + + @classmethod + def support_data_per_streamline(cls): + ''' Tells if this tractogram format supports saving data per streamline. 
''' + raise NotImplementedError() + @classmethod def is_correct_format(cls, fileobj): ''' Checks if the file has the right streamlines file format. diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py index 706afd348a..2298e395b2 100644 --- a/nibabel/streamlines/header.py +++ b/nibabel/streamlines/header.py @@ -19,7 +19,7 @@ class Field: DIMENSIONS = "dimensions" MAGIC_NUMBER = "magic_number" ORIGIN = "origin" - to_world_space = "to_world_space" + VOXEL_TO_RASMM = "voxel_to_rasmm" VOXEL_ORDER = "voxel_order" ENDIAN = "endian" @@ -38,7 +38,7 @@ def to_world_space(self): @to_world_space.setter def to_world_space(self, value): - self._to_world_space = np.array(value, dtype=np.float32) + self._to_world_space = np.asarray(value, dtype=np.float32) @property def voxel_sizes(self): diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index 5ae8c0d5d8..53d8589ab7 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -293,6 +293,15 @@ def test_tractogram_creation(self): assert_raises(ValueError, Tractogram, self.streamlines, data_per_point=data_per_point) + # Inconsistent number of scalars between streamlines + wrong_data = [[(1, 0, 0)]*1, + [(0, 1)]*2, + [(0, 0, 1)]*5] + + data_per_point = {'wrong_data': wrong_data} + assert_raises(ValueError, Tractogram, self.streamlines, + data_per_point=data_per_point) + def test_tractogram_getter(self): # Tractogram with only streamlines tractogram = Tractogram(streamlines=self.streamlines) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index e78be90429..5dabad1815 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -20,19 +20,20 @@ DATA_PATH = pjoin(os.path.dirname(__file__), 'data') -def check_tractogram(tractogram, nb_streamlines, streamlines, scalars, properties): +def check_tractogram(tractogram, nb_streamlines, streamlines, data_per_streamline, data_per_point): # Check data assert_equal(len(tractogram), nb_streamlines) assert_arrays_equal(tractogram.streamlines, streamlines) - assert_arrays_equal(tractogram.scalars, scalars) - assert_arrays_equal(tractogram.properties, properties) - assert_true(isiterable(tractogram)) - assert_equal(tractogram.header.nb_streamlines, nb_streamlines) - nb_scalars_per_point = 0 if len(scalars) == 0 else len(scalars[0][0]) - nb_properties_per_streamline = 0 if len(properties) == 0 else len(properties[0]) - assert_equal(tractogram.header.nb_scalars_per_point, nb_scalars_per_point) - assert_equal(tractogram.header.nb_properties_per_streamline, nb_properties_per_streamline) + for key in data_per_streamline.keys(): + assert_arrays_equal(tractogram.data_per_streamline[key], + data_per_streamline[key]) + + for key in data_per_point.keys(): + assert_arrays_equal(tractogram.data_per_point[key], + data_per_point[key]) + + assert_true(isiterable(tractogram)) def test_is_supported(): @@ -127,6 +128,9 @@ def setUp(self): np.array([2.11, 2.22], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] + self.data_per_point = {'scalars': self.colors} + self.data_per_streamline = {'properties': self.mean_curvature_torsion} + self.nb_streamlines = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) @@ -135,91 +139,94 @@ def setUp(self): def test_load_empty_file(self): for empty_filename in 
self.empty_filenames: tractogram_file = nib.streamlines.load(empty_filename, - lazy_load=False, - ref=self.to_world_space) + lazy_load=False) assert_true(isinstance(tractogram_file, TractogramFile)) assert_true(type(tractogram_file.tractogram), Tractogram) - check_tractogram(tractogram_file.tractogram, 0, [], [], []) + check_tractogram(tractogram_file.tractogram, 0, [], {}, {}) def test_load_simple_file(self): for simple_filename in self.simple_filenames: tractogram_file = nib.streamlines.load(simple_filename, - lazy_load=False, - ref=self.to_world_space) + lazy_load=False) assert_true(isinstance(tractogram_file, TractogramFile)) assert_true(type(tractogram_file.tractogram), Tractogram) check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - self.streamlines, [], []) + self.streamlines, {}, {}) - # Test lazy_load - tractogram_file = nib.streamlines.load(simple_filename, - lazy_load=True, - ref=self.to_world_space) + # # Test lazy_load + # tractogram_file = nib.streamlines.load(simple_filename, + # lazy_load=True) - assert_true(isinstance(tractogram_file, TractogramFile)) - assert_true(type(tractogram_file.tractogram), LazyTractogram) - check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - self.streamlines, [], []) + # assert_true(isinstance(tractogram_file, TractogramFile)) + # assert_true(type(tractogram_file.tractogram), LazyTractogram) + # check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + # self.streamlines, {}, {}) def test_load_complex_file(self): for complex_filename in self.complex_filenames: file_format = nib.streamlines.detect_format(complex_filename) - scalars = [] - if file_format.can_save_scalars(): - scalars = self.colors + data_per_point = {} + if file_format.support_data_per_point(): + data_per_point = {'scalars': self.colors} - properties = [] - if file_format.can_save_properties(): - properties = self.mean_curvature_torsion + data_per_streamline = [] + if file_format.support_data_per_streamline(): + data_per_streamline = {'properties': self.mean_curvature_torsion} tractogram_file = nib.streamlines.load(complex_filename, - lazy_load=False, - ref=self.to_world_space) + lazy_load=False) assert_true(isinstance(tractogram_file, TractogramFile)) assert_true(type(tractogram_file.tractogram), Tractogram) check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - self.streamlines, scalars, properties) - - # Test lazy_load - tractogram_file = nib.streamlines.load(complex_filename, - lazy_load=True, - ref=self.to_world_space) - assert_true(isinstance(tractogram_file, TractogramFile)) - assert_true(type(tractogram_file.tractogram), LazyTractogram) - check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - self.streamlines, scalars, properties) + self.streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + + # # Test lazy_load + # tractogram_file = nib.streamlines.load(complex_filename, + # lazy_load=True) + # assert_true(isinstance(tractogram_file, TractogramFile)) + # assert_true(type(tractogram_file.tractogram), LazyTractogram) + # check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + # self.streamlines, + # data_per_streamline=data_per_streamline, + # data_per_point=data_per_point) def test_save_simple_file(self): tractogram = Tractogram(self.streamlines) - for ext in nib.streamlines.FORMATS.keys(): + for ext, cls in nib.streamlines.FORMATS.items(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: - nib.streamlines.save(tractogram, f.name) - 
loaded_tractogram = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) - check_tractogram(loaded_tractogram, self.nb_streamlines, - self.streamlines, [], []) + nib.streamlines.save_tractogram(tractogram, f.name) + tractogram_file = nib.streamlines.load(f, lazy_load=False) + check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + self.streamlines, {}, {}) def test_save_complex_file(self): - tractogram = Tractogram(self.streamlines, scalars=self.colors, properties=self.mean_curvature_torsion) + tractogram = Tractogram(self.streamlines, + data_per_streamline=self.data_per_streamline, + data_per_point=self.data_per_point) for ext, cls in nib.streamlines.FORMATS.items(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: with clear_and_catch_warnings(record=True, modules=[trk]) as w: - nib.streamlines.save(tractogram, f.name) + nib.streamlines.save_tractogram(tractogram, f.name) - # If streamlines format does not support saving scalars or - # properties, a warning message should be issued. - if not (cls.can_save_scalars() and cls.can_save_properties()): + # If streamlines format does not support saving data per point + # or data per streamline, a warning message should be issued. + if not (cls.support_data_per_point() and cls.support_data_per_streamline()): assert_equal(len(w), 1) assert_true(issubclass(w[0].category, UsageWarning)) - scalars = [] - if cls.can_save_scalars(): - scalars = self.colors + data_per_point = {} + if cls.support_data_per_point(): + data_per_point = self.data_per_point - properties = [] - if cls.can_save_properties(): - properties = self.mean_curvature_torsion + data_per_streamline = [] + if cls.support_data_per_streamline(): + data_per_streamline = self.data_per_streamline - loaded_tractogram = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) - check_tractogram(loaded_tractogram, self.nb_streamlines, - self.streamlines, scalars, properties) + tractogram_file = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) + check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + self.streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index e1584731db..1b1c2ab7a6 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -5,7 +5,7 @@ from nibabel.externals.six import BytesIO from nibabel.testing import suppress_warnings, clear_and_catch_warnings -from nibabel.testing import assert_arrays_equal, assert_tractogram_equal, isiterable +from nibabel.testing import assert_arrays_equal, isiterable from nose.tools import assert_equal, assert_raises, assert_true from .. 
import base_format @@ -18,19 +18,36 @@ DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') -def check_tractogram(tractogram, nb_streamlines, streamlines, scalars, properties): +def assert_tractogram_equal(t1, t2): + assert_equal(len(t1), len(t2)) + assert_arrays_equal(t1.streamlines, t2.streamlines) + + assert_equal(len(t1.data_per_streamline), len(t2.data_per_streamline)) + for key in t1.data_per_streamline.keys(): + assert_arrays_equal(t1.data_per_streamline[key], + t2.data_per_streamline[key]) + + assert_equal(len(t1.data_per_point), len(t2.data_per_point)) + for key in t1.data_per_point.keys(): + assert_arrays_equal(t1.data_per_point[key], + t2.data_per_point[key]) + + + +def check_tractogram(tractogram, nb_streamlines, streamlines, data_per_streamline, data_per_point): # Check data assert_equal(len(tractogram), nb_streamlines) assert_arrays_equal(tractogram.streamlines, streamlines) - assert_arrays_equal(tractogram.scalars, scalars) - assert_arrays_equal(tractogram.properties, properties) - assert_true(isiterable(tractogram)) - assert_equal(tractogram.header.nb_streamlines, nb_streamlines) - nb_scalars_per_point = 0 if len(scalars) == 0 else len(scalars[0][0]) - nb_properties_per_streamline = 0 if len(properties) == 0 else len(properties[0]) - assert_equal(tractogram.header.nb_scalars_per_point, nb_scalars_per_point) - assert_equal(tractogram.header.nb_properties_per_streamline, nb_properties_per_streamline) + for key in data_per_streamline.keys(): + assert_arrays_equal(tractogram.data_per_streamline[key], + data_per_streamline[key]) + + for key in data_per_point.keys(): + assert_arrays_equal(tractogram.data_per_point[key], + data_per_point[key]) + + assert_true(isiterable(tractogram)) class TestTRK(unittest.TestCase): @@ -54,35 +71,40 @@ def setUp(self): np.array([2.11, 2.22], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] + self.data_per_point = {'scalars': self.colors} + self.data_per_streamline = {'properties': self.mean_curvature_torsion} + self.nb_streamlines = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) self.affine = np.eye(4) def test_load_empty_file(self): - trk = TrkFile.load(self.empty_trk_filename, ref=None, lazy_load=False) - check_tractogram(trk.tractogram, 0, [], [], []) + trk = TrkFile.load(self.empty_trk_filename, lazy_load=False) + check_tractogram(trk.tractogram, 0, [], {}, {}) - trk = TrkFile.load(self.empty_trk_filename, ref=None, lazy_load=True) - # Suppress warning about loading a TRK file in lazy mode with count=0. - with suppress_warnings(): - check_tractogram(trk.tractogram, 0, [], [], []) + # trk = TrkFile.load(self.empty_trk_filename, lazy_load=True) + # # Suppress warning about loading a TRK file in lazy mode with count=0. 
+ # with suppress_warnings(): + # check_tractogram(trk.tractogram, 0, [], [], []) def test_load_simple_file(self): - trk = TrkFile.load(self.simple_trk_filename, ref=None, lazy_load=False) - check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) + trk = TrkFile.load(self.simple_trk_filename, lazy_load=False) + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, {}, {}) - trk = TrkFile.load(self.simple_trk_filename, ref=None, lazy_load=True) - check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) + # trk = TrkFile.load(self.simple_trk_filename, lazy_load=True) + # check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) def test_load_complex_file(self): - trk = TrkFile.load(self.complex_trk_filename, ref=None, lazy_load=False) + trk = TrkFile.load(self.complex_trk_filename, lazy_load=False) check_tractogram(trk.tractogram, self.nb_streamlines, - self.streamlines, self.colors, self.mean_curvature_torsion) + self.streamlines, + data_per_point=self.data_per_point, + data_per_streamline=self.data_per_streamline) - trk = TrkFile.load(self.complex_trk_filename, ref=None, lazy_load=True) - check_tractogram(trk.tractogram, self.nb_streamlines, - self.streamlines, self.colors, self.mean_curvature_torsion) + # trk = TrkFile.load(self.complex_trk_filename, lazy_load=True) + # check_tractogram(trk.tractogram, self.nb_streamlines, + # self.streamlines, self.colors, self.mean_curvature_torsion) def test_load_file_with_wrong_information(self): trk_file = open(self.simple_trk_filename, 'rb').read() @@ -91,19 +113,19 @@ def test_load_file_with_wrong_information(self): count = np.array(0, dtype="int32").tostring() new_trk_file = trk_file[:1000-12] + count + trk_file[1000-8:] trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) - check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, {}, {}) - trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) - with clear_and_catch_warnings(record=True, modules=[base_format]) as w: - check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) - assert_equal(len(w), 1) - assert_true(issubclass(w[0].category, UsageWarning)) + # trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) + # with clear_and_catch_warnings(record=True, modules=[base_format]) as w: + # check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, {}, {}) + # assert_equal(len(w), 1) + # assert_true(issubclass(w[0].category, UsageWarning)) # Simulate a TRK file where `voxel_order` was not provided. 
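For context, both corruptions in this test rely on fixed offsets within the 1000-byte TRK header: 'voxel_order' starts at byte 948, and the streamline count occupies bytes 988 to 992, since the header ends with the three int32 fields n_count, version and hdr_size. A sketch of the same count patching with the arithmetic spelled out (`trk_file` is the raw bytes read above):

import numpy as np

HEADER_SIZE = 1000                      # fixed TRK header size
N_COUNT_OFFSET = HEADER_SIZE - 3 * 4    # n_count, version, hdr_size are the last three int32s

# Zero out the streamline count, as the test does with trk_file[:1000-12].
zero_count = np.array(0, dtype="int32").tostring()
patched = trk_file[:N_COUNT_OFFSET] + zero_count + trk_file[N_COUNT_OFFSET + 4:]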
voxel_order = np.zeros(1, dtype="|S3").tostring() new_trk_file = trk_file[:948] + voxel_order + trk_file[948+3:] with clear_and_catch_warnings(record=True, modules=[trk]) as w: - TrkFile.load(BytesIO(new_trk_file), ref=None) + TrkFile.load(BytesIO(new_trk_file)) assert_equal(len(w), 1) assert_true(issubclass(w[0].category, HeaderWarning)) assert_true("LPS" in str(w[0].message)) @@ -128,7 +150,7 @@ def test_write_simple_file(self): loaded_trk = TrkFile.load(trk_file) check_tractogram(loaded_trk.tractogram, self.nb_streamlines, - self.streamlines, [], []) + self.streamlines, {}, {}) loaded_trk_orig = TrkFile.load(self.simple_trk_filename) assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) @@ -138,7 +160,8 @@ def test_write_simple_file(self): def test_write_complex_file(self): # With scalars - tractogram = Tractogram(self.streamlines, scalars=self.colors) + tractogram = Tractogram(self.streamlines, + data_per_point=self.data_per_point) trk_file = BytesIO() trk = TrkFile(tractogram, ref=self.affine) @@ -146,23 +169,28 @@ def test_write_complex_file(self): trk_file.seek(0, os.SEEK_SET) loaded_trk = TrkFile.load(trk_file, lazy_load=False) - check_tractogram(loaded_trk.tractogram, self.nb_streamlines, - self.streamlines, self.colors, []) + check_tractogram(loaded_trk.tractogram, self.nb_streamlines, self.streamlines, + data_per_streamline={}, + data_per_point=self.data_per_point) # With properties - tractogram = Tractogram(self.streamlines, properties=self.mean_curvature_torsion) + tractogram = Tractogram(self.streamlines, + data_per_streamline=self.data_per_streamline) - trk_file = BytesIO() trk = TrkFile(tractogram, ref=self.affine) + trk_file = BytesIO() trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) loaded_trk = TrkFile.load(trk_file, lazy_load=False) - check_tractogram(loaded_trk.tractogram, self.nb_streamlines, - self.streamlines, [], self.mean_curvature_torsion) + check_tractogram(loaded_trk.tractogram, self.nb_streamlines, self.streamlines, + data_per_streamline=self.data_per_streamline, + data_per_point={}) # With scalars and properties - tractogram = Tractogram(self.streamlines, scalars=self.colors, properties=self.mean_curvature_torsion) + tractogram = Tractogram(self.streamlines, + data_per_point=self.data_per_point, + data_per_streamline=self.data_per_streamline) trk_file = BytesIO() trk = TrkFile(tractogram, ref=self.affine) @@ -170,8 +198,9 @@ def test_write_complex_file(self): trk_file.seek(0, os.SEEK_SET) loaded_trk = TrkFile.load(trk_file, lazy_load=False) - check_tractogram(loaded_trk.tractogram, self.nb_streamlines, - self.streamlines, self.colors, self.mean_curvature_torsion) + check_tractogram(loaded_trk.tractogram, self.nb_streamlines, self.streamlines, + data_per_streamline=self.data_per_streamline, + data_per_point=self.data_per_point) loaded_trk_orig = TrkFile.load(self.complex_trk_filename) assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) @@ -185,7 +214,8 @@ def test_write_erroneous_file(self): [(0, 1, 0)], [(0, 0, 1)]] - tractogram = Tractogram(self.streamlines, scalars) + tractogram = Tractogram(self.streamlines, + data_per_point={'scalars': scalars}) trk = TrkFile(tractogram, ref=self.affine) assert_raises(DataError, trk.save, BytesIO()) @@ -193,9 +223,10 @@ def test_write_erroneous_file(self): scalars = [[(1, 0, 0)]*1, [(0, 1, 0)]*2] - tractogram = Tractogram(self.streamlines, scalars) + tractogram = Tractogram(self.streamlines, + data_per_point={'scalars': scalars}) trk = TrkFile(tractogram, ref=self.affine) - 
assert_raises(DataError, trk.save, BytesIO()) + assert_raises(IndexError, trk.save, BytesIO()) # # Unit test moved to test_base_format.py # # Inconsistent number of scalars between points @@ -206,31 +237,22 @@ def test_write_erroneous_file(self): # tractogram = Tractogram(self.streamlines, scalars) # assert_raises(ValueError, TrkFile.save, tractogram, BytesIO()) - # # Unit test moved to test_base_format.py - # # Inconsistent number of scalars between streamlines - # scalars = [[(1, 0, 0)]*1, - # [(0, 1)]*2, - # [(0, 0, 1)]*5] - - # tractogram = Tractogram(self.streamlines, scalars) - # assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) - - # Unit test moved to test_base_format.py # Inconsistent number of properties properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] - tractogram = Tractogram(self.streamlines, properties=properties) + tractogram = Tractogram(self.streamlines, + data_per_streamline={'properties': properties}) trk = TrkFile(tractogram, ref=self.affine) assert_raises(DataError, trk.save, BytesIO()) - # Unit test moved to test_base_format.py # No properties for every streamlines properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11, 2.22], dtype="f4")] - tractogram = Tractogram(self.streamlines, properties=properties) + tractogram = Tractogram(self.streamlines, + data_per_streamline={'properties': properties}) trk = TrkFile(tractogram, ref=self.affine) - assert_raises(DataError, trk.save, BytesIO()) + assert_raises(IndexError, trk.save, BytesIO()) # def test_write_file_lazy_tractogram(self): # streamlines = lambda: (point for point in self.streamlines) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 61862f7a0e..b93a65f5d9 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -59,7 +59,7 @@ ('scalar_name', 'S20', 10), (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), ('property_name', 'S20', 10), - (Field.to_world_space, 'f4', (4, 4)), # new field for version 2 + (Field.VOXEL_TO_RASMM, 'f4', (4, 4)), # new field for version 2 ('reserved', 'S444'), (Field.VOXEL_ORDER, 'S4'), ('pad2', 'S4'), @@ -140,16 +140,6 @@ def __init__(self, fileobj): # Keep the file position where the data begin. self.offset_data = f.tell() - # if f.name is not None and self.header[Field.NB_STREAMLINES] > 0: - # filesize = os.path.getsize(f.name) - self.offset_data - # # Remove properties - # filesize -= self.header[Field.NB_STREAMLINES] * self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * 4. - # # Remove the points count at the beginning of each streamline. - # filesize -= self.header[Field.NB_STREAMLINES] * 4. - # # Get nb points. - # nb_points = filesize / ((3 + self.header[Field.NB_SCALARS_PER_POINT]) * 4.) 
-        # self.header[Field.NB_POINTS] = int(nb_points)
-
     def __iter__(self):
         i4_dtype = np.dtype(self.endianness + "i4")
         f4_dtype = np.dtype(self.endianness + "f4")
@@ -214,11 +204,11 @@ def create_empty_header(cls):
         header[Field.MAGIC_NUMBER] = TrkFile.MAGIC_NUMBER
         header[Field.VOXEL_SIZES] = (1, 1, 1)
         header[Field.DIMENSIONS] = (1, 1, 1)
-        header[Field.to_world_space] = np.eye(4)
+        header[Field.VOXEL_TO_RASMM] = np.eye(4)
         header['version'] = 2
         header['hdr_size'] = TrkFile.HEADER_SIZE

-        return header
+        return header[0]

     def __init__(self, fileobj, header):
         self.header = self.create_empty_header()
@@ -235,7 +225,7 @@ def __init__(self, fileobj, header):
         self.header[Field.NB_SCALARS_PER_POINT] = header.nb_scalars_per_point
         self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = header.nb_properties_per_streamline
         self.header[Field.VOXEL_SIZES] = header.voxel_sizes
-        self.header[Field.to_world_space] = header.to_world_space
+        self.header[Field.VOXEL_TO_RASMM] = header.to_world_space
         self.header[Field.VOXEL_ORDER] = header.voxel_order

         # Keep counts for correcting incoherent fields or warn.
@@ -248,27 +238,34 @@ def __init__(self, fileobj, header):
         self.file = Opener(fileobj, mode="wb")
         # Keep track of the beginning of the header.
         self.beginning = self.file.tell()
-        self.file.write(self.header[0].tostring())
+        self.file.write(self.header.tostring())

     def write(self, tractogram):
         i4_dtype = np.dtype("i4")
         f4_dtype = np.dtype("f4")

+        # TRK's streamlines need to be in 'voxelmm' space and by definition
+        # tractogram streamlines are in RAS+ and mm space.
+        affine = np.linalg.inv(self.header[Field.VOXEL_TO_RASMM])
+        affine[range(3), range(3)] *= self.header[Field.VOXEL_SIZES]
+
+        # TrackVis considers coordinate (0,0,0) to be the corner of the
+        # voxel whereas streamlines passed in parameters assume (0,0,0)
+        # to be the center of the voxel. Thus, streamlines are shifted by
+        # half a voxel.
+        affine[:-1, -1] += np.array(self.header[Field.VOXEL_SIZES])/2.
+
+        tractogram.apply_affine(affine)
+
         for t in tractogram:
-            if len(t.scalars) > 0 and len(t.scalars) != len(t.streamline):
+            if any(len(d) != len(t.streamline) for d in t.data_for_points.values()):
                 raise DataError("Missing scalars for some points!")

             points = np.asarray(t.streamline, dtype=f4_dtype)
-            scalars = np.asarray(t.scalars, dtype=f4_dtype).reshape((len(points), -1))
-            properties = np.asarray(t.properties, dtype=f4_dtype)
-
-            # TRK's streamlines need to be in 'voxelmm' space
-            points = points * self.header[Field.VOXEL_SIZES]
-            # TrackVis considers coordinate (0,0,0) to be the corner of the
-            # voxel whereas streamlines passed in parameters assume (0,0,0)
-            # to be the center of the voxel. Thus, streamlines are shifted of
-            # half a voxel.
-            points += np.array(self.header[Field.VOXEL_SIZES])/2.
+            keys = sorted(t.data_for_points.keys())
+            scalars = np.asarray([t.data_for_points[k] for k in keys], dtype=f4_dtype).reshape((len(points), -1))
+            keys = sorted(t.data_for_streamline.keys())
+            properties = np.asarray([t.data_for_streamline[k] for k in keys], dtype=f4_dtype).flatten()

             data = struct.pack(i4_dtype.str[:-1], len(points))
             data += np.concatenate((points, scalars), axis=1).tostring()
@@ -298,7 +295,91 @@ def write(self, tractogram):
         # Overwrite header with updated one.
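The lines that follow use a common pattern for streaming writers: write a placeholder header first, accumulate counts while serializing, then rewind and overwrite the header with the final values. A minimal, format-agnostic sketch of the idea (names are illustrative, not part of the TRK code):

import os
import struct

def write_with_count(f, items):
    beginning = f.tell()
    f.write(struct.pack("<i", 0))       # placeholder count
    count = 0
    for item in items:
        f.write(item)                   # serialize the payload
        count += 1
    end = f.tell()
    f.seek(beginning, os.SEEK_SET)
    f.write(struct.pack("<i", count))   # overwrite placeholder with final count
    f.seek(end, os.SEEK_SET)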
self.file.seek(self.beginning, os.SEEK_SET) - self.file.write(self.header[0].tostring()) + self.file.write(self.header.tostring()) + + +import itertools +from nibabel.streamlines.base_format import CompactList + +def create_compactlist_from_generator(gen): + BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. + + streamlines = CompactList() + scalars = CompactList() + properties = np.array([]) + + gen = iter(gen) + try: + first_element = next(gen) + gen = itertools.chain([first_element], gen) + except StopIteration: + return streamlines, scalars, properties + + # Allocated some buffer memory. + pts = np.asarray(first_element[0]) + scals = np.asarray(first_element[1]) + props = np.asarray(first_element[2]) + + scals_shape = scals.shape + props_shape = props.shape + + streamlines._data = np.empty((BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype) + scalars._data = np.empty((BUFFER_SIZE, scals.shape[1]), dtype=scals.dtype) + properties = np.empty((BUFFER_SIZE, props.shape[0]), dtype=props.dtype) + + offset = 0 + for i, (pts, scals, props) in enumerate(gen): + pts = np.asarray(pts) + scals = np.asarray(scals) + props = np.asarray(props) + + if scals.shape[1] != scals_shape[1]: + raise ValueError("Number of scalars differs from one" + " point or streamline to another") + + if props.shape != props_shape: + raise ValueError("Number of properties differs from one" + " streamline to another") + + end = offset + len(pts) + if end >= len(streamlines._data): + # Resize is needed (at least `len(pts)` items will be added). + streamlines._data.resize((len(streamlines._data) + len(pts)+BUFFER_SIZE, pts.shape[1])) + scalars._data.resize((len(scalars._data) + len(scals)+BUFFER_SIZE, scals.shape[1])) + + streamlines._offsets.append(offset) + streamlines._lengths.append(len(pts)) + streamlines._data[offset:offset+len(pts)] = pts + scalars._data[offset:offset+len(scals)] = scals + + offset += len(pts) + + if i >= len(properties): + properties.resize((len(properties) + BUFFER_SIZE, props.shape[0])) + + properties[i] = props + + # Clear unused memory. + streamlines._data.resize((offset, pts.shape[1])) + + if scals_shape[1] == 0: + # Because resizing an empty ndarray creates memory! + scalars._data = np.empty((offset, scals.shape[1])) + else: + scalars._data.resize((offset, scals.shape[1])) + + # Share offsets and lengths between streamlines and scalars. + scalars._offsets = streamlines._offsets + scalars._lengths = streamlines._lengths + + if props_shape[0] == 0: + # Because resizing an empty ndarray creates memory! + properties = np.empty((i+1, props.shape[0])) + else: + properties.resize((i+1, props.shape[0])) + + return streamlines, scalars, properties + class TrkFile(TractogramFile): @@ -319,19 +400,19 @@ class TrkFile(TractogramFile): MAGIC_NUMBER = b"TRACK" HEADER_SIZE = 1000 - def __init__(self, tractogram, ref, header=None): + def __init__(self, tractogram, header=None, ref=np.eye(4)): """ Parameters ---------- tractogram : ``Tractogram`` object Tractogram that will be contained in this ``TrkFile``. - ref : filename | `Nifti1Image` object | 2D array (4,4) - Reference space where streamlines live in. - - header : ``TractogramHeader`` file + header : ``TractogramHeader`` file (optional) Metadata associated to this tractogram file. + ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) + Reference space where streamlines live in. + Notes ----- Streamlines of the tractogram are assumed to be in *RAS+* and *mm* space @@ -345,6 +426,16 @@ def get_magic_number(cls): ''' Return TRK's magic number. 
''' return cls.MAGIC_NUMBER + @classmethod + def support_data_per_point(cls): + ''' Tells if this tractogram format supports saving data per point. ''' + return True + + @classmethod + def support_data_per_streamline(cls): + ''' Tells if this tractogram format supports saving data per streamline. ''' + return True + @classmethod def is_correct_format(cls, fileobj): ''' Check if the file is in TRK format. @@ -369,7 +460,7 @@ def is_correct_format(cls, fileobj): return False @classmethod - def load(cls, fileobj, lazy_load=False, ref=None): + def load(cls, fileobj, lazy_load=False): ''' Loads streamlines from a file-like object. Parameters @@ -383,9 +474,6 @@ def load(cls, fileobj, lazy_load=False, ref=None): Load streamlines in a lazy manner i.e. they will not be kept in memory. - ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) - Reference space where streamlines live in `fileobj`. - Returns ------- trk_file : ``TrkFile`` object @@ -407,7 +495,7 @@ def load(cls, fileobj, lazy_load=False, ref=None): # If voxel order implied from the affine does not match the voxel # order save in the TRK header, change the orientation. header_ornt = trk_reader.header[Field.VOXEL_ORDER] - affine_ornt = "".join(nib.orientations.aff2axcodes(trk_reader.header[Field.to_world_space])) + affine_ornt = "".join(nib.orientations.aff2axcodes(trk_reader.header[Field.VOXEL_TO_RASMM])) header_ornt = nib.orientations.axcodes2ornt(header_ornt) affine_ornt = nib.orientations.axcodes2ornt(affine_ornt) ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt) @@ -415,7 +503,7 @@ def load(cls, fileobj, lazy_load=False, ref=None): affine = np.dot(M, affine) # Applied the affine going from voxel space to rasmm. - affine = np.dot(trk_reader.header[Field.to_world_space], affine) + affine = np.dot(trk_reader.header[Field.VOXEL_TO_RASMM], affine) # TrackVis considers coordinate (0,0,0) to be the corner of the # voxel whereas streamlines returned assume (0,0,0) to be the @@ -424,6 +512,7 @@ def load(cls, fileobj, lazy_load=False, ref=None): affine[:-1, -1] -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2. if lazy_load: + # TODO when LazyTractogram has been refactored. def _apply_transform(trk_reader): for pts, scals, props in trk_reader: # TRK's streamlines are in 'voxelmm' space, we send them to voxel space. @@ -445,73 +534,32 @@ def _apply_transform(trk_reader): if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: tractogram.properties = lambda: [] - # elif Field.NB_POINTS in trk_reader.header: - # # 'count' field is provided, we can avoid creating list of numpy - # # arrays (more memory efficient). 
- - # nb_streamlines = trk_reader.header[Field.NB_STREAMLINES] - # nb_points = trk_reader.header[Field.NB_POINTS] - - # points = CompactList() - # points._data = np.empty((nb_points, 3), dtype=np.float32) - - # scalars = CompactList() - # scalars._data = np.empty((nb_points, trk_reader.header[Field.NB_SCALARS_PER_POINT]), - # dtype=np.float32) - - # properties = CompactList() - # properties._data = np.empty((nb_streamlines, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]), - # dtype=np.float32) - - # offset = 0 - # offsets = [] - # lengths = [] - # for i, (pts, scals, props) in enumerate(trk_reader): - # offsets.append(offset) - # lengths.append(len(pts)) - # points._data[offset:offset+len(pts)] = pts - # scalars._data[offset:offset+len(scals)] = scals - # properties._data[i] = props - # offset += len(pts) - - # points.offsets = offsets - # scalars.offsets = offsets - # points.lengths = lengths - # scalars.lengths = lengths - - # streamlines = Tractogram(points, scalars, properties) - # streamlines.apply_affine(affine) - - # # Overwrite scalars and properties if there is none - # if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: - # streamlines.scalars = [] - # if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: - # streamlines.properties = [] - else: - tractogram = Tractogram.create_from_generator(trk_reader) - #tractogram = Tractogram(*zip(*trk_reader)) - tractogram.apply_affine(affine) - - # Overwrite scalars and properties if there is none - if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: - tractogram.scalars = [] - if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: - tractogram.properties = [] - - # Set available common information about streamlines in the header - tractogram.header.to_world_space = affine - - # If 'count' field is 0, i.e. not provided, we don't set `nb_streamlines` - if trk_reader.header[Field.NB_STREAMLINES] > 0: - tractogram.header.nb_streamlines = trk_reader.header[Field.NB_STREAMLINES] - - # Keep extra information about TRK format - tractogram.header.extra = trk_reader.header + streamlines, scalars, properties = create_compactlist_from_generator(trk_reader) + tractogram = Tractogram(streamlines) + + if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0: + if len(trk_reader.header['scalar_name'][0]) > 0: + for i in range(trk_reader.header[Field.NB_SCALARS_PER_POINT]): + clist = CompactList() + clist._data = scalars._data[:, i] + clist._offsets = scalars._offsets + clist._lengths = scalars._lengths + tractogram.data_per_point[trk_reader.header['scalar_name'][i]] = clist + else: + tractogram.data_per_point['scalars'] = scalars + + if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: + if len(trk_reader.header['property_name'][0]) > 0: + for i in range(trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]): + tractogram.data_per_streamline[trk_reader.header['property_name'][i]] = properties[:, i] + else: + tractogram.data_per_streamline['properties'] = properties + + # Bring tractogram to RAS+ and mm space + tractogram.apply_affine(affine) ## Perform some integrity checks - #if trk_reader.header[Field.VOXEL_ORDER] != tractogram.header.voxel_order: - # raise HeaderError("'voxel_order' does not match the affine.") #if tractogram.header.voxel_sizes != trk_reader.header[Field.VOXEL_SIZES]: # raise HeaderError("'voxel_sizes' does not match the affine.") #if tractogram.header.nb_scalars_per_point != trk_reader.header[Field.NB_SCALARS_PER_POINT]: @@ -532,14 +580,8 @@ def save(self, fileobj): of the TRK header data). 
''' # Update header using the tractogram. - self.header.nb_scalars_per_point = 0 - if self.tractogram.scalars.shape is not None: - self.header.nb_scalars_per_point = len(self.tractogram.scalars[0]) - - self.header.nb_properties_per_streamline = 0 - if self.tractogram.properties.shape is not None: - self.header.nb_properties_per_streamline = len(self.tractogram.properties[0]) - + self.header.nb_scalars_per_point = sum(map(lambda e: len(e[0]), self.tractogram.data_per_point.values())) + self.header.nb_properties_per_streamline = sum(map(lambda e: len(e[0]), self.tractogram.data_per_streamline.values())) trk_writer = TrkWriter(fileobj, self.header) trk_writer.write(self.tractogram) @@ -572,7 +614,7 @@ def pretty_print(fileobj): info += "scalar_name:\n {0}".format("\n".join(hdr['scalar_name'])) info += "nb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) info += "property_name:\n {0}".format("\n".join(hdr['property_name'])) - info += "vox_to_world: {0}".format(hdr[Field.to_world_space]) + info += "vox_to_world: {0}".format(hdr[Field.VOXEL_TO_RASMM]) info += "voxel_order: {0}".format(hdr[Field.VOXEL_ORDER]) info += "image_orientation_patient: {0}".format(hdr['image_orientation_patient']) info += "pad1: {0}".format(hdr['pad1']) diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py index 935ae4ad00..d29aea2ca3 100644 --- a/nibabel/testing/__init__.py +++ b/nibabel/testing/__init__.py @@ -74,14 +74,6 @@ def assert_arrays_equal(arrays1, arrays2): assert_array_equal(arr1, arr2) -def assert_tractogram_equal(s1, s2): - assert_equal(s1.header, s2.header) - assert_equal(len(s1), len(s2)) - assert_arrays_equal(s1.streamlines, s2.streamlines) - assert_arrays_equal(s1.scalars, s2.scalars) - assert_arrays_equal(s1.properties, s2.properties) - - def get_fresh_mod(mod_name=__name__): # Get this module, with warning registry empty my_mod = sys.modules[mod_name] From 66a0ce761387e1d7ef5f2c4750b242272a4fe306 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 3 Nov 2015 08:38:16 -0500 Subject: [PATCH 22/54] Refactored LazyTractogram --- nibabel/streamlines/base_format.py | 126 +++++++++++++++++++---------- 1 file changed, 84 insertions(+), 42 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 9276f01f2c..5d5cd0e6e0 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -322,7 +322,6 @@ def __len__(self): def copy(self): """ Returns a copy of this `Tractogram` object. """ - new_data_per_streamline = {} for key in self.data_per_streamline: new_data_per_streamline[key] = self.data_per_streamline[key].copy() @@ -390,10 +389,12 @@ class LazyTractogram(Tractogram): If provided, ``scalars`` and ``properties`` must yield the same number of values as ``streamlines``. 
    '''
-    def __init__(self, streamlines_func=lambda:[], scalars_func=lambda: [], properties_func=lambda: [], getitem_func=None):
-        super(LazyTractogram, self).__init__(streamlines_func, scalars_func, properties_func)
-        self._data = lambda: zip_longest(self.streamlines, self.scalars, self.properties, fillvalue=[])
-        self._getitem = getitem_func
+    def __init__(self, streamlines=lambda: [], data_per_streamline=None, data_per_point=None):
+        super(LazyTractogram, self).__init__(streamlines, data_per_streamline, data_per_point)
+        self.nb_streamlines = None
+        self._data = None
+        self._getitem = None
+        self._affine_to_apply = np.eye(4)

     @classmethod
     def create_from_data(cls, data_func):
@@ -422,6 +423,13 @@ def create_from_data(cls, data_func):

     @property
     def streamlines(self):
+        if not np.all(self._affine_to_apply == np.eye(4)):
+            def _transform():
+                for s in self._streamlines():
+                    yield apply_affine(self._affine_to_apply, s)
+
+            return _transform()
+
         return self._streamlines()

     @streamlines.setter
@@ -432,34 +440,70 @@ def streamlines(self, value):
         self._streamlines = value

     @property
-    def scalars(self):
-        return self._scalars()
+    def data_per_streamline(self):
+        return self._data_per_streamline

-    @scalars.setter
-    def scalars(self, value):
-        if not callable(value):
-            raise TypeError("`scalars` must be a coroutine.")
+    @data_per_streamline.setter
+    def data_per_streamline(self, value):
+        if value is None:
+            value = {}

-        self._scalars = value
-        self.header.nb_scalars_per_point = 0
-        scalars = pop(self.scalars)
-        if scalars is not None and len(scalars) > 0:
-            self.header.nb_scalars_per_point = len(scalars[0])
+        self._data_per_streamline = {}
+        for k, v in value.items():
+            if not callable(v):
+                raise TypeError("`data_per_streamline` must be a dict of coroutines.")
+
+            self._data_per_streamline[k] = v

     @property
-    def properties(self):
-        return self._properties()
+    def data_per_point(self):
+        return self._data_per_point

-    @properties.setter
-    def properties(self, value):
+    @data_per_point.setter
+    def data_per_point(self, value):
+        if value is None:
+            value = {}
+
+        self._data_per_point = {}
+        for k, v in value.items():
+            if not callable(v):
+                raise TypeError("`data_per_point` must be a dict of coroutines.")
+
+            self._data_per_point[k] = v
+
+    @property
+    def data(self):
+        if self._data is not None:
+            return self._data()
+
+        def _gen_data():
+            data_per_streamline_generators = {}
+            for k, v in self.data_per_streamline.items():
+                data_per_streamline_generators[k] = iter(v())
+
+            data_per_point_generators = {}
+            for k, v in self.data_per_point.items():
+                data_per_point_generators[k] = iter(v())
+
+            for s in self.streamlines:
+                data_for_streamline = {}
+                for k, v in data_per_streamline_generators.items():
+                    data_for_streamline[k] = next(v)
+
+                data_for_points = {}
+                for k, v in data_per_point_generators.items():
+                    data_for_points[k] = next(v)
+
+                yield TractogramItem(s, data_for_streamline, data_for_points)
+
+        return _gen_data()
+
+    @data.setter
+    def data(self, value):
         if not callable(value):
-            raise TypeError("`properties` must be a coroutine.")
+            raise TypeError("`data` must be a coroutine.")

-        self._properties = value
-        self.header.nb_properties_per_streamline = 0
-        properties = pop(self.properties)
-        if properties is not None:
-            self.header.nb_properties_per_streamline = len(properties)
+        self._data = value

     def __getitem__(self, idx):
         if self._getitem is None:
@@ -469,45 +513,43 @@ def __getitem__(self, idx):

     def __iter__(self):
         i = 0
-        for i, s in enumerate(self._data(), start=1):
-            yield TractogramItem(*s)
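The `data` property above pairs one fresh generator per data key with the streamline generator and advances them in lockstep with `next`. A toy sketch of that pairing, with illustrative names that are not part of the nibabel API:

import numpy as np

def paired_items(streamlines_fn, props_fn):
    # Build fresh iterators on each call so iteration is repeatable.
    props = iter(props_fn())
    for s in streamlines_fn():
        yield np.asarray(s), next(props)   # advance both streams in lockstep

pairs = list(paired_items(lambda: [np.zeros((2, 3)), np.ones((3, 3))],
                          lambda: [np.array([1.]), np.array([2.])]))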
+ for i, tractogram_item in enumerate(self.data, start=1): + yield tractogram_item # To be safe, update information about number of streamlines. - self.header.nb_streamlines = i + self.nb_streamlines = i def __len__(self): # If length is unknown, we obtain it by iterating through streamlines. - if self.header.nb_streamlines is None: + if self.nb_streamlines is None: warn("Number of streamlines will be determined manually by looping" " through the streamlines. If you know the actual number of" " streamlines, you might want to set it beforehand via" " `self.header.nb_streamlines`." " Note this will consume any generators used to create this" " `LazyTractogram` object.", UsageWarning) - return sum(1 for _ in self) + return sum(1 for _ in self.streamlines) - return self.header.nb_streamlines + return self.nb_streamlines def copy(self): """ Returns a copy of this `LazyTractogram` object. """ - streamlines = LazyTractogram(self._streamlines, self._scalars, self._properties) - streamlines._header = self.header.copy() - return streamlines + tractogram = LazyTractogram(self._streamlines, + self._data_per_streamline, + self._data_per_point) + tractogram.nb_streamlines = self.nb_streamlines + tractogram._data = self._data + return tractogram - def transform(self, affine): + def apply_affine(self, affine): """ Applies an affine transformation on the streamlines. Parameters ---------- affine : 2D array (4,4) Transformation that will be applied on each streamline. - - Returns - ------- - streamlines : `LazyTractogram` object - Tractogram living in a space defined by `affine`. """ - return super(LazyTractogram, self).transform(affine, lazy=True) + self._affine_to_apply = np.dot(affine, self._affine_to_apply) class abstractclassmethod(classmethod): From 3b705690b9f301e5891f4914bb7642044f5c23df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 8 Nov 2015 13:23:58 -0500 Subject: [PATCH 23/54] Added CompactList to init --- nibabel/streamlines/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index cfd8990741..9c24d6f2b0 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -1,4 +1,5 @@ from .header import TractogramHeader +from .base_format import CompactList from .base_format import Tractogram, LazyTractogram from nibabel.streamlines.trk import TrkFile From f5f4f91176ac5dfdc0fdf116765fc00dd5ed9002 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 8 Nov 2015 13:36:03 -0500 Subject: [PATCH 24/54] Added get_streamlines method to TractogramFile --- nibabel/streamlines/base_format.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 5d5cd0e6e0..b5c9f3c03f 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -583,6 +583,9 @@ def header(self): def get_tractogram(self): return self.tractogram + def get_streamlines(self): + return self.streamlines + def get_header(self): return self.header From 4480505bb97a91539d4d754105b8ba8adb84e6b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 8 Nov 2015 14:07:01 -0500 Subject: [PATCH 25/54] Added save and load support for compact_list --- nibabel/streamlines/tests/test_utils.py | 26 +++++++++++++++++++++++++ nibabel/streamlines/utils.py | 17 ++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/nibabel/streamlines/tests/test_utils.py 
b/nibabel/streamlines/tests/test_utils.py index 4cceef35c4..b5c67e0770 100644 --- a/nibabel/streamlines/tests/test_utils.py +++ b/nibabel/streamlines/tests/test_utils.py @@ -1,5 +1,6 @@ import os import unittest +import tempfile import numpy as np import nibabel as nib @@ -7,7 +8,9 @@ from numpy.testing import assert_array_equal from nose.tools import assert_equal, assert_raises, assert_true +from ..base_format import CompactList from ..utils import pop, get_affine_from_reference +from ..utils import save_compact_list, load_compact_list def test_peek(): @@ -33,3 +36,26 @@ def test_get_affine_from_reference(): # Get affine from a `SpatialImage` using by its filename. assert_array_equal(get_affine_from_reference(filename), affine) + + +def test_save_and_load_compact_list(): + + with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + clist = CompactList() + save_compact_list(f, clist) + f.seek(0, os.SEEK_SET) + loaded_clist = load_compact_list(f) + assert_array_equal(loaded_clist._data, clist._data) + assert_array_equal(loaded_clist._offsets, clist._offsets) + assert_array_equal(loaded_clist._lengths, clist._lengths) + + with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + clist = CompactList(data) + save_compact_list(f, clist) + f.seek(0, os.SEEK_SET) + loaded_clist = load_compact_list(f) + assert_array_equal(loaded_clist._data, clist._data) + assert_array_equal(loaded_clist._offsets, clist._offsets) + assert_array_equal(loaded_clist._lengths, clist._lengths) diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index 7bbbe1ef8d..a44b25fe1b 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -33,3 +33,20 @@ def pop(iterable): "Returns the next item from the iterable else None" value = list(itertools.islice(iterable, 1)) return value[0] if len(value) > 0 else None + + +def save_compact_list(filename, compact_list): + np.savez(filename, + data=compact_list._data, + offsets=compact_list._offsets, + lengths=compact_list._lengths) + + +def load_compact_list(filename): + from .base_format import CompactList + content = np.load(filename) + compact_list = CompactList() + compact_list._data = content["data"] + compact_list._offsets = content["offsets"].tolist() + compact_list._lengths = content["lengths"].tolist() + return compact_list From 051673230eb985f4f29d3cba2991c152cd822329 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 8 Nov 2015 14:08:12 -0500 Subject: [PATCH 26/54] DOC: load and save utils functions --- nibabel/streamlines/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index a44b25fe1b..7fbfab5106 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -36,6 +36,7 @@ def pop(iterable): def save_compact_list(filename, compact_list): + """ Saves a `CompactList` object to a .npz file. """ np.savez(filename, data=compact_list._data, offsets=compact_list._offsets, @@ -43,6 +44,7 @@ def save_compact_list(filename, compact_list): def load_compact_list(filename): + """ Loads a `CompactList` object from a .npz file. 
""" from .base_format import CompactList content = np.load(filename) compact_list = CompactList() From ba9e3f4e3e617290584a100ffa973204e8f1f19d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 11 Nov 2015 10:21:22 -0500 Subject: [PATCH 27/54] Refactored streamlines API and added unit tests for LazyTractogram --- nibabel/streamlines/__init__.py | 4 +- nibabel/streamlines/base_format.py | 640 ------------------ nibabel/streamlines/compact_list.py | 198 ++++++ .../streamlines/tests/test_compact_list.py | 222 ++++++ ...test_base_format.py => test_tractogram.py} | 315 ++------- nibabel/streamlines/tests/test_utils.py | 27 - nibabel/streamlines/tractogram.py | 390 +++++++++++ nibabel/streamlines/tractogram_file.py | 103 +++ nibabel/streamlines/trk.py | 17 +- nibabel/streamlines/utils.py | 19 - 10 files changed, 989 insertions(+), 946 deletions(-) create mode 100644 nibabel/streamlines/compact_list.py create mode 100644 nibabel/streamlines/tests/test_compact_list.py rename nibabel/streamlines/tests/{test_base_format.py => test_tractogram.py} (52%) create mode 100644 nibabel/streamlines/tractogram.py create mode 100644 nibabel/streamlines/tractogram_file.py diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 9c24d6f2b0..abf9ca27f3 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -1,6 +1,6 @@ from .header import TractogramHeader -from .base_format import CompactList -from .base_format import Tractogram, LazyTractogram +from .compact_list import CompactList +from .tractogram import Tractogram, LazyTractogram from nibabel.streamlines.trk import TrkFile #from nibabel.streamlines.tck import TckFile diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index b5c9f3c03f..e694e7e0c4 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -1,18 +1,3 @@ -import itertools -import numpy as np -from warnings import warn - -from abc import ABCMeta, abstractmethod, abstractproperty - -from nibabel.externals.six.moves import zip_longest -from nibabel.affines import apply_affine - -from .header import TractogramHeader -from .utils import pop - - -class UsageWarning(Warning): - pass class HeaderWarning(Warning): @@ -25,628 +10,3 @@ class HeaderError(Exception): class DataError(Exception): pass - - -class CompactList(object): - """ Class for compacting list of ndarrays with matching shape except for - the first dimension. - """ - def __init__(self, iterable=None): - """ - Parameters - ---------- - iterable : iterable (optional) - If specified, create a ``CompactList`` object initialized from - iterable's items. Otherwise, create an empty ``CompactList``. - - Notes - ----- - If `iterable` is a ``CompactList`` object, a view is returned and no - memory is allocated. For an actual copy use the `.copy()` method. - """ - # Create new empty `CompactList` object. - self._data = None - self._offsets = [] - self._lengths = [] - - if isinstance(iterable, CompactList): - # Create a view. - self._data = iterable._data - self._offsets = iterable._offsets - self._lengths = iterable._lengths - - elif iterable is not None: - # Initialize the `CompactList` object from iterable's item. - BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. 
- - offset = 0 - for i, e in enumerate(iterable): - e = np.asarray(e) - if i == 0: - self._data = np.empty((BUFFER_SIZE,) + e.shape[1:], dtype=e.dtype) - - end = offset + len(e) - if end >= len(self._data): - # Resize is needed (at least `len(e)` items will be added). - self._data.resize((len(self._data) + len(e)+BUFFER_SIZE,) + self.shape) - - self._offsets.append(offset) - self._lengths.append(len(e)) - self._data[offset:offset+len(e)] = e - offset += len(e) - - # Clear unused memory. - if self._data is not None: - self._data.resize((offset,) + self.shape) - - @property - def shape(self): - """ Returns the matching shape of the elements in this compact list. """ - if self._data is None: - return None - - return self._data.shape[1:] - - def append(self, element): - """ Appends `element` to this compact list. - - Parameters - ---------- - element : ndarray - Element to append. The shape must match already inserted elements - shape except for the first dimension. - - Notes - ----- - If you need to add multiple elements you should consider - `CompactList.extend`. - """ - if self._data is None: - self._data = np.asarray(element).copy() - self._offsets.append(0) - self._lengths.append(len(element)) - return - - if element.shape[1:] != self.shape: - raise ValueError("All dimensions, except the first one, must match exactly") - - self._offsets.append(len(self._data)) - self._lengths.append(len(element)) - self._data = np.append(self._data, element, axis=0) - - def extend(self, elements): - """ Appends all `elements` to this compact list. - - Parameters - ---------- - element : list of ndarrays, ``CompactList`` object - Elements to append. The shape must match already inserted elements - shape except for the first dimension. - """ - if isinstance(elements, CompactList): - self._data = np.concatenate([self._data, elements._data], axis=0) - offset = self._offsets[-1] + self._lengths[-1] if len(self) > 0 else 0 - self._lengths.extend(elements._lengths) - self._offsets.extend(np.cumsum([offset] + elements._lengths).tolist()[:-1]) - else: - self._data = np.concatenate([self._data] + list(elements), axis=0) - offset = self._offsets[-1] + self._lengths[-1] if len(self) > 0 else 0 - lengths = map(len, elements) - self._lengths.extend(lengths) - self._offsets.extend(np.cumsum([offset] + lengths).tolist()[:-1]) - - def copy(self): - """ Creates a copy of this ``CompactList`` object. """ - # We cannot just deepcopy this object since we don't know if it has been created - # using slicing. If it is the case, `self.data` probably contains more data than necessary - # so we copy only elements according to `self._offsets`. - compact_list = CompactList() - total_lengths = np.sum(self._lengths) - compact_list._data = np.empty((total_lengths,) + self._data.shape[1:], dtype=self._data.dtype) - - cur_offset = 0 - for offset, lengths in zip(self._offsets, self._lengths): - compact_list._offsets.append(cur_offset) - compact_list._lengths.append(lengths) - compact_list._data[cur_offset:cur_offset+lengths] = self._data[offset:offset+lengths] - cur_offset += lengths - - return compact_list - - def __getitem__(self, idx): - """ Gets element(s) through indexing. - - Parameters - ---------- - idx : int, slice or list - Index of the element(s) to get. - - Returns - ------- - ndarray object(s) - When `idx` is a int, returns a single ndarray. - When `idx` is either a slice or a list, returns a list of ndarrays. 
- """ - if isinstance(idx, int) or isinstance(idx, np.integer): - return self._data[self._offsets[idx]:self._offsets[idx]+self._lengths[idx]] - - elif type(idx) is slice: - # TODO: Should we have a CompactListView class that would be - # returned when slicing? - compact_list = CompactList() - compact_list._data = self._data - compact_list._offsets = self._offsets[idx] - compact_list._lengths = self._lengths[idx] - return compact_list - - elif type(idx) is list: - # TODO: Should we have a CompactListView class that would be - # returned when doing advance indexing? - compact_list = CompactList() - compact_list._data = self._data - compact_list._offsets = [self._offsets[i] for i in idx] - compact_list._lengths = [self._lengths[i] for i in idx] - return compact_list - - raise TypeError("Index must be a int or a slice! Not " + str(type(idx))) - - def __iter__(self): - if len(self._lengths) != len(self._offsets): - raise ValueError("CompactList object corrupted: len(self._lengths) != len(self._offsets)") - - for offset, lengths in zip(self._offsets, self._lengths): - yield self._data[offset: offset+lengths] - - def __len__(self): - return len(self._offsets) - - def __repr__(self): - return repr(list(self)) - - -class TractogramItem(object): - """ Class containing information about one streamline. - - ``TractogramItem`` objects have three main properties: `streamline`, - `data_for_streamline`, and `data_for_points`. - - Parameters - ---------- - streamline : ndarray of shape (N, 3) - Points of this streamline represented as an ndarray of shape (N, 3) - where N is the number of points. - - data_for_streamline : dict - - data_for_points : dict - """ - def __init__(self, streamline, data_for_streamline, data_for_points): - self.streamline = np.asarray(streamline) - self.data_for_streamline = data_for_streamline - self.data_for_points = data_for_points - - def __iter__(self): - return iter(self.streamline) - - def __len__(self): - return len(self.streamline) - - -class Tractogram(object): - """ Class containing information about streamlines. - - Tractogram objects have three main properties: ``streamlines`` - - Parameters - ---------- - streamlines : list of ndarray of shape (Nt, 3) - Sequence of T streamlines. One streamline is an ndarray of shape - (Nt, 3) where Nt is the number of points of streamline t. - - data_per_streamline : dictionary of list of ndarray of shape (P,) - Sequence of T ndarrays of shape (P,) where T is the number of - streamlines defined by ``streamlines``, P is the number of properties - associated to each streamline. - - data_per_point : dictionary of list of ndarray of shape (Nt, M) - Sequence of T ndarrays of shape (Nt, M) where T is the number of - streamlines defined by ``streamlines``, Nt is the number of points - for a particular streamline t and M is the number of scalars - associated to each point (excluding the three coordinates). 
- - """ - def __init__(self, streamlines=None, - data_per_streamline=None, - data_per_point=None): - - self.streamlines = streamlines - self.data_per_streamline = data_per_streamline - self.data_per_point = data_per_point - - @property - def streamlines(self): - return self._streamlines - - @streamlines.setter - def streamlines(self, value): - self._streamlines = CompactList(value) - - @property - def data_per_streamline(self): - return self._data_per_streamline - - @data_per_streamline.setter - def data_per_streamline(self, value): - if value is None: - value = {} - - self._data_per_streamline = {} - for k, v in value.items(): - self._data_per_streamline[k] = np.asarray(v) - - @property - def data_per_point(self): - return self._data_per_point - - @data_per_point.setter - def data_per_point(self, value): - if value is None: - value = {} - - self._data_per_point = {} - for k, v in value.items(): - self._data_per_point[k] = CompactList(v) - - def __iter__(self): - for i in range(len(self.streamlines)): - yield self[i] - - def __getitem__(self, idx): - pts = self.streamlines[idx] - - new_data_per_streamline = {} - for key in self.data_per_streamline: - new_data_per_streamline[key] = self.data_per_streamline[key][idx] - - new_data_per_point = {} - for key in self.data_per_point: - new_data_per_point[key] = self.data_per_point[key][idx] - - if type(idx) is slice: - return Tractogram(pts, new_data_per_streamline, new_data_per_point) - - return TractogramItem(pts, new_data_per_streamline, new_data_per_point) - - def __len__(self): - return len(self.streamlines) - - def copy(self): - """ Returns a copy of this `Tractogram` object. """ - new_data_per_streamline = {} - for key in self.data_per_streamline: - new_data_per_streamline[key] = self.data_per_streamline[key].copy() - - new_data_per_point = {} - for key in self.data_per_point: - new_data_per_point[key] = self.data_per_point[key].copy() - - tractogram = Tractogram(self.streamlines.copy(), - new_data_per_streamline, - new_data_per_point) - return tractogram - - def apply_affine(self, affine): - """ Applies an affine transformation on the points of each streamline. - - This is performed in-place. - - Parameters - ---------- - affine : 2D array (4,4) - Transformation that will be applied on each streamline. - """ - if len(self.streamlines) == 0: - return - - BUFFER_SIZE = 10000000 # About 128 Mb since pts shape is 3. - for i in range(0, len(self.streamlines._data), BUFFER_SIZE): - pts = self.streamlines._data[i:i+BUFFER_SIZE] - self.streamlines._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts) - - -class LazyTractogram(Tractogram): - ''' Class containing information about streamlines. - - Tractogram objects have four main properties: ``header``, ``streamlines``, - ``scalars`` and ``properties``. Tractogram objects are iterable and - produce tuple of ``streamlines``, ``scalars`` and ``properties`` for each - streamline. - - Parameters - ---------- - streamlines_func : coroutine ouputting (Nt,3) array-like (optional) - Function yielding streamlines. One streamline is - an ndarray of shape (Nt,3) where Nt is the number of points of - streamline t. - - scalars_func : coroutine ouputting (Nt,M) array-like (optional) - Function yielding scalars for a particular streamline t. The scalars - are represented as an ndarray of shape (Nt,M) where Nt is the number - of points of that streamline t and M is the number of scalars - associated to each point (excluding the three coordinates). 
- - properties_func : coroutine ouputting (P,) array-like (optional) - Function yielding properties for a particular streamline t. The - properties are represented as an ndarray of shape (P,) where P is - the number of properties associated to each streamline. - - getitem_func : function `idx -> 3-tuples` (optional) - Function returning a subset of the tractogram given an index or a - slice (i.e. the __getitem__ function to use). - - Notes - ----- - If provided, ``scalars`` and ``properties`` must yield the same number of - values as ``streamlines``. - ''' - def __init__(self, streamlines=lambda:[], data_per_streamline=None, data_per_point=None): - super(LazyTractogram, self).__init__(streamlines, data_per_streamline, data_per_point) - self.nb_streamlines = None - self._data = None - self._getitem = None - self._affine_to_apply = np.eye(4) - - @classmethod - def create_from_data(cls, data_func): - ''' Saves streamlines to a file-like object. - - Parameters - ---------- - data_func : coroutine ouputting tuple (optional) - Function yielding 3-tuples, (streamlines, scalars, properties). - Streamlines are represented as an ndarray of shape (Nt,3), scalars - as an ndarray of shape (Nt,M) and properties as an ndarray of shape - (P,) where Nt is the number of points for a particular - streamline t, M is the number of scalars associated to each point - (excluding the three coordinates) and P is the number of properties - associated to each streamline. - ''' - if not callable(data_func): - raise TypeError("`data` must be a coroutine.") - - lazy_streamlines = cls() - lazy_streamlines._data = data_func - lazy_streamlines.streamlines = lambda: (x[0] for x in data_func()) - lazy_streamlines.scalars = lambda: (x[1] for x in data_func()) - lazy_streamlines.properties = lambda: (x[2] for x in data_func()) - return lazy_streamlines - - @property - def streamlines(self): - if not np.all(self._affine_to_apply == np.eye(4)): - def _transform(): - for s in self._streamlines(): - yield apply_affine(self._affine_to_apply, s) - - return _transform() - - return self._streamlines() - - @streamlines.setter - def streamlines(self, value): - if not callable(value): - raise TypeError("`streamlines` must be a coroutine.") - - self._streamlines = value - - @property - def data_per_streamline(self): - return self._data_per_streamline - - @data_per_streamline.setter - def data_per_streamline(self, value): - if value is None: - value = {} - - self._data_per_streamline = {} - for k, v in value.items(): - if not callable(v): - raise TypeError("`data_per_streamline` must be a dict of coroutines.") - - self._data_per_streamline[k] = v - - @property - def data_per_point(self): - return self._data_per_point - - @data_per_point.setter - def data_per_point(self, value): - if value is None: - value = {} - - self._data_per_point = {} - for k, v in value.items(): - if not callable(v): - raise TypeError("`data_per_point` must be a dict of coroutines.") - - self._data_per_point[k] = v - - @property - def data(self): - if self._data is not None: - return self._data() - - def _gen_data(): - data_per_streamline_generators = {} - for k, v in self.data_per_streamline.items(): - data_per_streamline_generators[k] = iter(v()) - - data_per_point_generators = {} - for k, v in self.data_per_point.items(): - data_per_point_generators[k] = iter(v()) - - for s in self.streamlines: - data_for_streamline = {} - for k, v in data_per_streamline_generators.items(): - data_for_streamline[k] = next(v) - - data_for_points = {} - for k, v in 
data_per_point_generators.items(): - data_for_points[k] = v() - - yield TractogramItem(s, data_for_streamline, data_for_points) - - return _gen_data() - - @data.setter - def data(self, value): - if not callable(value): - raise TypeError("`data` must be a coroutine.") - - self._data = value - - def __getitem__(self, idx): - if self._getitem is None: - raise AttributeError('`LazyTractogram` does not support indexing.') - - return self._getitem(idx) - - def __iter__(self): - i = 0 - for i, tractogram_item in enumerate(self.data, start=1): - yield tractogram_item - - # To be safe, update information about number of streamlines. - self.nb_streamlines = i - - def __len__(self): - # If length is unknown, we obtain it by iterating through streamlines. - if self.nb_streamlines is None: - warn("Number of streamlines will be determined manually by looping" - " through the streamlines. If you know the actual number of" - " streamlines, you might want to set it beforehand via" - " `self.header.nb_streamlines`." - " Note this will consume any generators used to create this" - " `LazyTractogram` object.", UsageWarning) - return sum(1 for _ in self.streamlines) - - return self.nb_streamlines - - def copy(self): - """ Returns a copy of this `LazyTractogram` object. """ - tractogram = LazyTractogram(self._streamlines, - self._data_per_streamline, - self._data_per_point) - tractogram.nb_streamlines = self.nb_streamlines - tractogram._data = self._data - return tractogram - - def apply_affine(self, affine): - """ Applies an affine transformation on the streamlines. - - Parameters - ---------- - affine : 2D array (4,4) - Transformation that will be applied on each streamline. - """ - self._affine_to_apply = np.dot(affine, self._affine_to_apply) - - -class abstractclassmethod(classmethod): - __isabstractmethod__ = True - - def __init__(self, callable): - callable.__isabstractmethod__ = True - super(abstractclassmethod, self).__init__(callable) - - -class TractogramFile(object): - ''' Convenience class to encapsulate tractogram file format. ''' - __metaclass__ = ABCMeta - - def __init__(self, tractogram, header=None): - self._tractogram = tractogram - self._header = TractogramHeader() if header is None else header - - @property - def tractogram(self): - return self._tractogram - - @property - def streamlines(self): - return self.tractogram.streamlines - - @property - def header(self): - return self._header - - def get_tractogram(self): - return self.tractogram - - def get_streamlines(self): - return self.streamlines - - def get_header(self): - return self.header - - @classmethod - def get_magic_number(cls): - ''' Returns streamlines file's magic number. ''' - raise NotImplementedError() - - @classmethod - def support_data_per_point(cls): - ''' Tells if this tractogram format supports saving data per point. ''' - raise NotImplementedError() - - @classmethod - def support_data_per_streamline(cls): - ''' Tells if this tractogram format supports saving data per streamline. ''' - raise NotImplementedError() - - @classmethod - def is_correct_format(cls, fileobj): - ''' Checks if the file has the right streamlines file format. - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to a streamlines file (and ready to read from the - beginning of the header). - Returns - ------- - is_correct_format : boolean - Returns True if `fileobj` is in the right streamlines file format. 
- ''' - raise NotImplementedError() - - @abstractclassmethod - def load(cls, fileobj, lazy_load=True): - ''' Loads streamlines from a file-like object. - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to a streamlines file (and ready to read from the - beginning of the header). - lazy_load : boolean (optional) - Load streamlines in a lazy manner i.e. they will not be kept - in memory. For postprocessing speed, turn off this option. - Returns - ------- - tractogram_file : ``TractogramFile`` object - Returns an object containing tractogram data and header - information. - ''' - raise NotImplementedError() - - @abstractmethod - def save(self, fileobj): - ''' Saves streamlines to a file-like object. - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - opened and ready to write. - ''' - raise NotImplementedError() diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py new file mode 100644 index 0000000000..87671bb5c2 --- /dev/null +++ b/nibabel/streamlines/compact_list.py @@ -0,0 +1,198 @@ +import numpy as np + + +class CompactList(object): + """ Class for compacting list of ndarrays with matching shape except for + the first dimension. + """ + def __init__(self, iterable=None): + """ + Parameters + ---------- + iterable : iterable (optional) + If specified, create a ``CompactList`` object initialized from + iterable's items. Otherwise, create an empty ``CompactList``. + + Notes + ----- + If `iterable` is a ``CompactList`` object, a view is returned and no + memory is allocated. For an actual copy use the `.copy()` method. + """ + # Create new empty `CompactList` object. + self._data = None + self._offsets = [] + self._lengths = [] + + if isinstance(iterable, CompactList): + # Create a view. + self._data = iterable._data + self._offsets = iterable._offsets + self._lengths = iterable._lengths + + elif iterable is not None: + # Initialize the `CompactList` object from iterable's item. + BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. + + offset = 0 + for i, e in enumerate(iterable): + e = np.asarray(e) + if i == 0: + self._data = np.empty((BUFFER_SIZE,) + e.shape[1:], + dtype=e.dtype) + + end = offset + len(e) + if end >= len(self._data): + # Resize is needed (at least `len(e)` items will be added). + self._data.resize((len(self._data) + len(e)+BUFFER_SIZE,) + + self.shape) + + self._offsets.append(offset) + self._lengths.append(len(e)) + self._data[offset:offset+len(e)] = e + offset += len(e) + + # Clear unused memory. + if self._data is not None: + self._data.resize((offset,) + self.shape) + + @property + def shape(self): + """ Returns the matching shape of the elements in this compact list. """ + if self._data is None: + return None + + return self._data.shape[1:] + + def append(self, element): + """ Appends `element` to this compact list. + + Parameters + ---------- + element : ndarray + Element to append. The shape must match already inserted elements + shape except for the first dimension. + + Notes + ----- + If you need to add multiple elements you should consider + `CompactList.extend`. 
+ """ + if self._data is None: + self._data = np.asarray(element).copy() + self._offsets.append(0) + self._lengths.append(len(element)) + return + + if element.shape[1:] != self.shape: + raise ValueError("All dimensions, except the first one," + " must match exactly") + + self._offsets.append(len(self._data)) + self._lengths.append(len(element)) + self._data = np.append(self._data, element, axis=0) + + def extend(self, elements): + """ Appends all `elements` to this compact list. + + Parameters + ---------- + elements : list of ndarrays, ``CompactList`` object + Elements to append. The shape must match already inserted elements + shape except for the first dimension. + + """ + if isinstance(elements, CompactList): + self._data = np.concatenate([self._data, elements._data], axis=0) + lengths = elements._lengths + else: + self._data = np.concatenate([self._data] + list(elements), axis=0) + lengths = map(len, elements) + + idx = self._offsets[-1] + self._lengths[-1] if len(self) > 0 else 0 + self._lengths.extend(lengths) + self._offsets.extend(np.cumsum([idx] + lengths).tolist()[:-1]) + + def copy(self): + """ Creates a copy of this ``CompactList`` object. """ + # We do not simply deepcopy this object since we might have a chance + # to use less memory. For example, if the compact list being copied + # is the result of a slicing operation on a compact list. + clist = CompactList() + total_lengths = np.sum(self._lengths) + clist._data = np.empty((total_lengths,) + self._data.shape[1:], + dtype=self._data.dtype) + + idx = 0 + for offset, length in zip(self._offsets, self._lengths): + clist._offsets.append(idx) + clist._lengths.append(length) + clist._data[idx:idx+length] = self._data[offset:offset+length] + idx += length + + return clist + + def __getitem__(self, idx): + """ Gets element(s) through indexing. + + Parameters + ---------- + idx : int, slice or list + Index of the element(s) to get. + + Returns + ------- + ndarray object(s) + When `idx` is a int, returns a single ndarray. + When `idx` is either a slice or a list, returns a list of ndarrays. + """ + if isinstance(idx, int) or isinstance(idx, np.integer): + start = self._offsets[idx] + return self._data[start:start+self._lengths[idx]] + + elif type(idx) is slice: + compact_list = CompactList() + compact_list._data = self._data + compact_list._offsets = self._offsets[idx] + compact_list._lengths = self._lengths[idx] + return compact_list + + elif type(idx) is list: + compact_list = CompactList() + compact_list._data = self._data + compact_list._offsets = [self._offsets[i] for i in idx] + compact_list._lengths = [self._lengths[i] for i in idx] + return compact_list + + raise TypeError("Index must be a int or a slice! Not " + str(type(idx))) + + def __iter__(self): + if len(self._lengths) != len(self._offsets): + raise ValueError("CompactList object corrupted:" + " len(self._lengths) != len(self._offsets)") + + for offset, lengths in zip(self._offsets, self._lengths): + yield self._data[offset: offset+lengths] + + def __len__(self): + return len(self._offsets) + + def __repr__(self): + return repr(list(self)) + + +def save_compact_list(filename, compact_list): + """ Saves a `CompactList` object to a .npz file. """ + np.savez(filename, + data=compact_list._data, + offsets=compact_list._offsets, + lengths=compact_list._lengths) + + +def load_compact_list(filename): + """ Loads a `CompactList` object from a .npz file. 
""" + content = np.load(filename) + compact_list = CompactList() + compact_list._data = content["data"] + compact_list._offsets = content["offsets"].tolist() + compact_list._lengths = content["lengths"].tolist() + return compact_list diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py new file mode 100644 index 0000000000..2211d9d303 --- /dev/null +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -0,0 +1,222 @@ +import os +import unittest +import tempfile +import numpy as np + +from nose.tools import assert_equal, assert_raises, assert_true +from numpy.testing import assert_array_equal +from nibabel.externals.six.moves import zip, zip_longest + +from ..compact_list import (CompactList, + load_compact_list, + save_compact_list) + + +class TestCompactList(unittest.TestCase): + + def setUp(self): + rng = np.random.RandomState(42) + self.data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + self.lengths = map(len, self.data) + self.clist = CompactList(self.data) + + def test_creating_empty_compactlist(self): + clist = CompactList() + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + def test_creating_compactlist_from_list(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = map(len, data) + + clist = CompactList(data) + assert_equal(len(clist), len(data)) + assert_equal(len(clist._offsets), len(data)) + assert_equal(len(clist._lengths), len(data)) + assert_equal(clist._data.shape[0], sum(lengths)) + assert_equal(clist._data.shape[1], 3) + assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist._lengths, lengths) + assert_equal(clist.shape, data[0].shape[1:]) + + # Empty list + clist = CompactList([]) + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + def test_creating_compactlist_from_generator(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = map(len, data) + + gen = (e for e in data) + clist = CompactList(gen) + assert_equal(len(clist), len(data)) + assert_equal(len(clist._offsets), len(data)) + assert_equal(len(clist._lengths), len(data)) + assert_equal(clist._data.shape[0], sum(lengths)) + assert_equal(clist._data.shape[1], 3) + assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist._lengths, lengths) + assert_equal(clist.shape, data[0].shape[1:]) + + # Already consumed generator + clist = CompactList(gen) + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + def test_creating_compactlist_from_compact_list(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = map(len, data) + + clist = CompactList(data) + clist2 = CompactList(clist) + assert_equal(len(clist2), len(data)) + assert_equal(len(clist2._offsets), len(data)) + assert_equal(len(clist2._lengths), len(data)) + assert_equal(clist2._data.shape[0], sum(lengths)) + assert_equal(clist2._data.shape[1], 3) + assert_equal(clist2._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist2._lengths, lengths) + 
assert_equal(clist2.shape, data[0].shape[1:]) + + def test_compactlist_iter(self): + for e, d in zip(self.clist, self.data): + assert_array_equal(e, d) + + def test_compactlist_copy(self): + clist = self.clist.copy() + assert_array_equal(clist._data, self.clist._data) + assert_true(clist._data is not self.clist._data) + assert_array_equal(clist._offsets, self.clist._offsets) + assert_true(clist._offsets is not self.clist._offsets) + assert_array_equal(clist._lengths, self.clist._lengths) + assert_true(clist._lengths is not self.clist._lengths) + + assert_equal(clist.shape, self.clist.shape) + + # When taking a copy of a `CompactList` generated by slicing. + # Only needed data should be kept. + clist = self.clist[::2].copy() + + assert_true(clist._data.shape[0] < self.clist._data.shape[0]) + assert_true(len(clist) < len(self.clist)) + assert_true(clist._data is not self.clist._data) + + def test_compactlist_append(self): + # Maybe not necessary if `self.setUp` is always called before a + # test method, anyways create a copy just in case. + clist = self.clist.copy() + + rng = np.random.RandomState(1234) + element = rng.rand(rng.randint(10, 50), *self.clist.shape) + clist.append(element) + assert_equal(len(clist), len(self.clist)+1) + assert_equal(clist._offsets[-1], len(self.clist._data)) + assert_equal(clist._lengths[-1], len(element)) + assert_array_equal(clist._data[-len(element):], element) + + # Append with different shape. + element = rng.rand(rng.randint(10, 50), 42) + assert_raises(ValueError, clist.append, element) + + # Append to an empty CompactList. + clist = CompactList() + rng = np.random.RandomState(1234) + shape = (2, 3, 4) + element = rng.rand(rng.randint(10, 50), *shape) + clist.append(element) + + assert_equal(len(clist), 1) + assert_equal(clist._offsets[-1], 0) + assert_equal(clist._lengths[-1], len(element)) + assert_array_equal(clist._data, element) + assert_equal(clist.shape, shape) + + def test_compactlist_extend(self): + # Maybe not necessary if `self.setUp` is always called before a + # test method, anyways create a copy just in case. + clist = self.clist.copy() + + rng = np.random.RandomState(1234) + shape = self.clist.shape + new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(5)] + lengths = map(len, new_data) + clist.extend(new_data) + assert_equal(len(clist), len(self.clist)+len(new_data)) + assert_array_equal(clist._offsets[-len(new_data):], + len(self.clist._data) + np.cumsum([0] + lengths[:-1])) + + assert_equal(clist._lengths[-len(new_data):], lengths) + assert_array_equal(clist._data[-sum(lengths):], + np.concatenate(new_data, axis=0)) + + # Extend with another `CompactList` object. + clist = self.clist.copy() + new_data = CompactList(new_data) + clist.extend(new_data) + assert_equal(len(clist), len(self.clist)+len(new_data)) + assert_array_equal(clist._offsets[-len(new_data):], + len(self.clist._data) + np.cumsum([0] + lengths[:-1])) + + assert_equal(clist._lengths[-len(new_data):], lengths) + assert_array_equal(clist._data[-sum(lengths):], new_data._data) + + def test_compactlist_getitem(self): + # Get one item + for i, e in enumerate(self.clist): + assert_array_equal(self.clist[i], e) + + # Get multiple items (this will create a view). 
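+ # The view shares the underlying `_data` buffer with the original; only the offsets and lengths are fresh lists.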
+ clist_view = self.clist[range(len(self.clist))] + assert_true(clist_view is not self.clist) + assert_true(clist_view._data is self.clist._data) + assert_true(clist_view._offsets is not self.clist._offsets) + assert_true(clist_view._lengths is not self.clist._lengths) + assert_array_equal(clist_view._offsets, self.clist._offsets) + assert_array_equal(clist_view._lengths, self.clist._lengths) + for e1, e2 in zip_longest(clist_view, self.clist): + assert_array_equal(e1, e2) + + # Get slice (this will create a view). + clist_view = self.clist[::2] + assert_true(clist_view is not self.clist) + assert_true(clist_view._data is self.clist._data) + assert_array_equal(clist_view._offsets, self.clist._offsets[::2]) + assert_array_equal(clist_view._lengths, self.clist._lengths[::2]) + for i, e in enumerate(clist_view): + assert_array_equal(e, self.clist[i*2]) + + +def test_save_and_load_compact_list(): + + with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + clist = CompactList() + save_compact_list(f, clist) + f.seek(0, os.SEEK_SET) + loaded_clist = load_compact_list(f) + assert_array_equal(loaded_clist._data, clist._data) + assert_array_equal(loaded_clist._offsets, clist._offsets) + assert_array_equal(loaded_clist._lengths, clist._lengths) + + with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + clist = CompactList(data) + save_compact_list(f, clist) + f.seek(0, os.SEEK_SET) + loaded_clist = load_compact_list(f) + assert_array_equal(loaded_clist._data, clist._data) + assert_array_equal(loaded_clist._offsets, clist._offsets) + assert_array_equal(loaded_clist._lengths, clist._lengths) diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_tractogram.py similarity index 52% rename from nibabel/streamlines/tests/test_base_format.py rename to nibabel/streamlines/tests/test_tractogram.py index 53d8589ab7..f106e5e08b 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -1,4 +1,3 @@ -import os import unittest import numpy as np import warnings @@ -6,202 +5,12 @@ from nibabel.testing import assert_arrays_equal, isiterable from nibabel.testing import suppress_warnings, clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true -from numpy.testing import assert_array_equal, assert_array_almost_equal -from nibabel.externals.six.moves import zip, zip_longest +from numpy.testing import assert_array_equal +from nibabel.externals.six.moves import zip -from .. 
import base_format -from ..base_format import CompactList -from ..base_format import TractogramItem, Tractogram, LazyTractogram -from ..base_format import UsageWarning - -DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') - - -class TestCompactList(unittest.TestCase): - - def setUp(self): - rng = np.random.RandomState(42) - self.data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - self.lengths = map(len, self.data) - self.clist = CompactList(self.data) - - def test_creating_empty_compactlist(self): - clist = CompactList() - assert_equal(len(clist), 0) - assert_equal(len(clist._offsets), 0) - assert_equal(len(clist._lengths), 0) - assert_true(clist._data is None) - assert_true(clist.shape is None) - - def test_creating_compactlist_from_list(self): - rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = map(len, data) - - clist = CompactList(data) - assert_equal(len(clist), len(data)) - assert_equal(len(clist._offsets), len(data)) - assert_equal(len(clist._lengths), len(data)) - assert_equal(clist._data.shape[0], sum(lengths)) - assert_equal(clist._data.shape[1], 3) - assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) - assert_equal(clist._lengths, lengths) - assert_equal(clist.shape, data[0].shape[1:]) - - # Empty list - clist = CompactList([]) - assert_equal(len(clist), 0) - assert_equal(len(clist._offsets), 0) - assert_equal(len(clist._lengths), 0) - assert_true(clist._data is None) - assert_true(clist.shape is None) - - def test_creating_compactlist_from_generator(self): - rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = map(len, data) - - gen = (e for e in data) - clist = CompactList(gen) - assert_equal(len(clist), len(data)) - assert_equal(len(clist._offsets), len(data)) - assert_equal(len(clist._lengths), len(data)) - assert_equal(clist._data.shape[0], sum(lengths)) - assert_equal(clist._data.shape[1], 3) - assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) - assert_equal(clist._lengths, lengths) - assert_equal(clist.shape, data[0].shape[1:]) - - # Already consumed generator - clist = CompactList(gen) - assert_equal(len(clist), 0) - assert_equal(len(clist._offsets), 0) - assert_equal(len(clist._lengths), 0) - assert_true(clist._data is None) - assert_true(clist.shape is None) - - def test_creating_compactlist_from_compact_list(self): - rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = map(len, data) - - clist = CompactList(data) - clist2 = CompactList(clist) - assert_equal(len(clist2), len(data)) - assert_equal(len(clist2._offsets), len(data)) - assert_equal(len(clist2._lengths), len(data)) - assert_equal(clist2._data.shape[0], sum(lengths)) - assert_equal(clist2._data.shape[1], 3) - assert_equal(clist2._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) - assert_equal(clist2._lengths, lengths) - assert_equal(clist2.shape, data[0].shape[1:]) - - def test_compactlist_iter(self): - for e, d in zip(self.clist, self.data): - assert_array_equal(e, d) - - def test_compactlist_copy(self): - clist = self.clist.copy() - assert_array_equal(clist._data, self.clist._data) - assert_true(clist._data is not self.clist._data) - assert_array_equal(clist._offsets, self.clist._offsets) - assert_true(clist._offsets is not self.clist._offsets) - assert_array_equal(clist._lengths, self.clist._lengths) - assert_true(clist._lengths is not self.clist._lengths) - - assert_equal(clist.shape, 
self.clist.shape) - - # When taking a copy of a `CompactList` generated by slicing. - # Only needed data should be kept. - clist = self.clist[::2].copy() - - assert_true(clist._data.shape[0] < self.clist._data.shape[0]) - assert_true(len(clist) < len(self.clist)) - assert_true(clist._data is not self.clist._data) - - def test_compactlist_append(self): - # Maybe not necessary if `self.setUp` is always called before a - # test method, anyways create a copy just in case. - clist = self.clist.copy() - - rng = np.random.RandomState(1234) - element = rng.rand(rng.randint(10, 50), *self.clist.shape) - clist.append(element) - assert_equal(len(clist), len(self.clist)+1) - assert_equal(clist._offsets[-1], len(self.clist._data)) - assert_equal(clist._lengths[-1], len(element)) - assert_array_equal(clist._data[-len(element):], element) - - # Append with different shape. - element = rng.rand(rng.randint(10, 50), 42) - assert_raises(ValueError, clist.append, element) - - # Append to an empty CompactList. - clist = CompactList() - rng = np.random.RandomState(1234) - shape = (2, 3, 4) - element = rng.rand(rng.randint(10, 50), *shape) - clist.append(element) - - assert_equal(len(clist), 1) - assert_equal(clist._offsets[-1], 0) - assert_equal(clist._lengths[-1], len(element)) - assert_array_equal(clist._data, element) - assert_equal(clist.shape, shape) - - def test_compactlist_extend(self): - # Maybe not necessary if `self.setUp` is always called before a - # test method, anyways create a copy just in case. - clist = self.clist.copy() - - rng = np.random.RandomState(1234) - shape = self.clist.shape - new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(5)] - lengths = map(len, new_data) - clist.extend(new_data) - assert_equal(len(clist), len(self.clist)+len(new_data)) - assert_array_equal(clist._offsets[-len(new_data):], - len(self.clist._data) + np.cumsum([0] + lengths[:-1])) - - assert_equal(clist._lengths[-len(new_data):], lengths) - assert_array_equal(clist._data[-sum(lengths):], - np.concatenate(new_data, axis=0)) - - # Extend with another `CompactList` object. - clist = self.clist.copy() - new_data = CompactList(new_data) - clist.extend(new_data) - assert_equal(len(clist), len(self.clist)+len(new_data)) - assert_array_equal(clist._offsets[-len(new_data):], - len(self.clist._data) + np.cumsum([0] + lengths[:-1])) - - assert_equal(clist._lengths[-len(new_data):], lengths) - assert_array_equal(clist._data[-sum(lengths):], new_data._data) - - def test_compactlist_getitem(self): - # Get one item - for i, e in enumerate(self.clist): - assert_array_equal(self.clist[i], e) - - # Get multiple items (this will create a view). - clist_view = self.clist[range(len(self.clist))] - assert_true(clist_view is not self.clist) - assert_true(clist_view._data is self.clist._data) - assert_true(clist_view._offsets is not self.clist._offsets) - assert_true(clist_view._lengths is not self.clist._lengths) - assert_array_equal(clist_view._offsets, self.clist._offsets) - assert_array_equal(clist_view._lengths, self.clist._lengths) - for e1, e2 in zip_longest(clist_view, self.clist): - assert_array_equal(e1, e2) - - # Get slice (this will create a view). - clist_view = self.clist[::2] - assert_true(clist_view is not self.clist) - assert_true(clist_view._data is self.clist._data) - assert_array_equal(clist_view._offsets, self.clist._offsets[::2]) - assert_array_equal(clist_view._lengths, self.clist._lengths[::2]) - for i, e in enumerate(clist_view): - assert_array_equal(e, self.clist[i*2]) +from .. 
import tractogram as module_tractogram +from ..tractogram import UsageWarning +from ..tractogram import TractogramItem, Tractogram, LazyTractogram class TestTractogramItem(unittest.TestCase): @@ -247,7 +56,7 @@ def setUp(self): [0, 0, 1], [1, 0, 0]], dtype="f4") - self.nb_tractogram = len(self.streamlines) + self.nb_streamlines = len(self.streamlines) def test_tractogram_creation(self): # Create an empty tractogram. @@ -379,10 +188,6 @@ def test_tractogram_add_new_data(self): class TestLazyTractogram(unittest.TestCase): def setUp(self): - self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") - self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") - self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") - self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), np.arange(2*3, dtype="f4").reshape((2, 3)), np.arange(5*3, dtype="f4").reshape((5, 3))] @@ -391,71 +196,81 @@ def setUp(self): np.array([(0, 1, 0)]*2, dtype="f4"), np.array([(0, 0, 1)]*5, dtype="f4")] - self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), - np.array([2.11, 2.22], dtype="f4"), - np.array([3.11, 3.22], dtype="f4")] + self.mean_curvature = np.array([1.11, 2.11, 3.11], dtype="f4") + self.mean_color = np.array([[0, 1, 0], + [0, 0, 1], + [1, 0, 0]], dtype="f4") self.nb_streamlines = len(self.streamlines) - self.nb_scalars_per_point = self.colors[0].shape[1] - self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) def test_lazy_tractogram_creation(self): # To create tractogram from arrays use `Tractogram`. assert_raises(TypeError, LazyTractogram, self.streamlines) - # Points, scalars and properties + # Streamlines and other data as generators streamlines = (x for x in self.streamlines) - scalars = (x for x in self.colors) - properties = (x for x in self.mean_curvature_torsion) + data_per_point = {"colors": (x for x in self.colors)} + data_per_streamline = {'mean_curv': (x for x in self.mean_curvature), + 'mean_color': (x for x in self.mean_color)} - # Creating LazyTractogram from generators is not allowed as - # generators get exhausted and are not reusable unline coroutines. + # Creating LazyTractogram with generators is not allowed as + # generators get exhausted and are not reusable unlike coroutines. assert_raises(TypeError, LazyTractogram, streamlines) - assert_raises(TypeError, LazyTractogram, self.streamlines, scalars) - assert_raises(TypeError, LazyTractogram, properties_func=properties) + assert_raises(TypeError, LazyTractogram, + data_per_streamline=data_per_streamline) + assert_raises(TypeError, LazyTractogram, self.streamlines, + data_per_point=data_per_point) # Empty `LazyTractogram` tractogram = LazyTractogram() - with suppress_warnings(): - assert_equal(len(tractogram), 0) + assert_true(isiterable(tractogram)) + assert_equal(len(tractogram), 0) assert_arrays_equal(tractogram.streamlines, []) - assert_arrays_equal(tractogram.scalars, []) - assert_arrays_equal(tractogram.properties, []) - - # Check if we can iterate through the tractogram. 
- for streamline in tractogram: - pass + assert_equal(tractogram.data_per_point, {}) + assert_equal(tractogram.data_per_streamline, {}) - # Points, scalars and properties + # Create tractogram with streamlines and other data streamlines = lambda: (x for x in self.streamlines) - scalars = lambda: (x for x in self.colors) - properties = lambda: (x for x in self.mean_curvature_torsion) + data_per_point = {"colors": lambda: (x for x in self.colors)} + data_per_streamline = {'mean_curv': lambda: (x for x in self.mean_curvature), + 'mean_color': lambda: (x for x in self.mean_color)} - tractogram = LazyTractogram(streamlines, scalars, properties) - with suppress_warnings(): - assert_equal(len(tractogram), self.nb_streamlines) + tractogram = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + + assert_true(isiterable(tractogram)) + assert_equal(len(tractogram), self.nb_streamlines) # Coroutines get re-called and creates new iterators. - assert_arrays_equal(tractogram.streamlines, self.streamlines) - assert_arrays_equal(tractogram.scalars, self.colors) - assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) - assert_arrays_equal(tractogram.streamlines, self.streamlines) - assert_arrays_equal(tractogram.scalars, self.colors) - assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) + for i in range(2): + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.data_per_streamline['mean_curv'], + self.mean_curvature) + assert_arrays_equal(tractogram.data_per_streamline['mean_color'], + self.mean_color) + assert_arrays_equal(tractogram.data_per_point['colors'], + self.colors) # Create `LazyTractogram` from a coroutine yielding 3-tuples - data = lambda: (x for x in zip(self.streamlines, self.colors, self.mean_curvature_torsion)) - - tractogram = LazyTractogram.create_from_data(data) - with suppress_warnings(): - assert_equal(len(tractogram), self.nb_streamlines) + def _data_gen(): + for d in zip(self.streamlines, self.colors, + self.mean_curvature, self.mean_color): + data_for_points = {'colors': d[1]} + data_for_streamline = {'mean_curv': d[2], + 'mean_color': d[3]} + yield TractogramItem(d[0], data_for_streamline, data_for_points) + + tractogram = LazyTractogram.create_from(_data_gen) + assert_true(isiterable(tractogram)) + assert_equal(len(tractogram), self.nb_streamlines) assert_arrays_equal(tractogram.streamlines, self.streamlines) - assert_arrays_equal(tractogram.scalars, self.colors) - assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) - - # Check if we can iterate through the tractogram. 
- for streamline in tractogram: - pass + assert_arrays_equal(tractogram.data_per_streamline['mean_curv'], + self.mean_curvature) + assert_arrays_equal(tractogram.data_per_streamline['mean_color'], + self.mean_color) + assert_arrays_equal(tractogram.data_per_point['colors'], + self.colors) def test_lazy_tractogram_indexing(self): streamlines = lambda: (x for x in self.streamlines) @@ -473,7 +288,8 @@ def getitem_without_properties(idx): return list(zip(self.streamlines[idx], self.colors[idx])) - tractogram = LazyTractogram(streamlines, scalars, properties, getitem_without_properties) + tractogram = LazyTractogram(streamlines, scalars, properties, + getitem_without_properties) streamlines, scalars = tractogram[0] assert_array_equal(streamlines, self.streamlines[0]) assert_array_equal(scalars, self.colors[0]) @@ -491,7 +307,7 @@ def test_lazy_tractogram_len(self): scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - with clear_and_catch_warnings(record=True, modules=[base_format]) as w: + with clear_and_catch_warnings(record=True, modules=[module_tractogram]) as w: warnings.simplefilter("always") # Always trigger warnings. # Calling `len` will create new generators each time. @@ -510,19 +326,20 @@ def test_lazy_tractogram_len(self): assert_equal(len(tractogram), self.nb_streamlines) assert_equal(len(w), 2) - with clear_and_catch_warnings(record=True, modules=[base_format]) as w: + with clear_and_catch_warnings(record=True, modules=[module_tractogram]) as w: # Once we iterated through the tractogram, we know the length. tractogram = LazyTractogram(streamlines, scalars, properties) assert_true(tractogram.header.nb_streamlines is None) for streamline in tractogram: pass - assert_equal(tractogram.header.nb_streamlines, len(self.streamlines)) + assert_equal(tractogram.header.nb_streamlines, + len(self.streamlines)) # This should *not* produce a warning. assert_equal(len(tractogram), len(self.streamlines)) assert_equal(len(w), 0) - with clear_and_catch_warnings(record=True, modules=[base_format]) as w: + with clear_and_catch_warnings(record=True, modules=[module_tractogram]) as w: # It first checks if number of tractogram is in the header. tractogram = LazyTractogram(streamlines, scalars, properties) tractogram.header.nb_streamlines = 1234 diff --git a/nibabel/streamlines/tests/test_utils.py b/nibabel/streamlines/tests/test_utils.py index b5c67e0770..6c3bf096a6 100644 --- a/nibabel/streamlines/tests/test_utils.py +++ b/nibabel/streamlines/tests/test_utils.py @@ -1,6 +1,4 @@ import os -import unittest -import tempfile import numpy as np import nibabel as nib @@ -8,9 +6,7 @@ from numpy.testing import assert_array_equal from nose.tools import assert_equal, assert_raises, assert_true -from ..base_format import CompactList from ..utils import pop, get_affine_from_reference -from ..utils import save_compact_list, load_compact_list def test_peek(): @@ -36,26 +32,3 @@ def test_get_affine_from_reference(): # Get affine from a `SpatialImage` using by its filename. 
assert_array_equal(get_affine_from_reference(filename), affine) - - -def test_save_and_load_compact_list(): - - with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: - clist = CompactList() - save_compact_list(f, clist) - f.seek(0, os.SEEK_SET) - loaded_clist = load_compact_list(f) - assert_array_equal(loaded_clist._data, clist._data) - assert_array_equal(loaded_clist._offsets, clist._offsets) - assert_array_equal(loaded_clist._lengths, clist._lengths) - - with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: - rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - clist = CompactList(data) - save_compact_list(f, clist) - f.seek(0, os.SEEK_SET) - loaded_clist = load_compact_list(f) - assert_array_equal(loaded_clist._data, clist._data) - assert_array_equal(loaded_clist._offsets, clist._offsets) - assert_array_equal(loaded_clist._lengths, clist._lengths) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py new file mode 100644 index 0000000000..c7e0593065 --- /dev/null +++ b/nibabel/streamlines/tractogram.py @@ -0,0 +1,390 @@ +import numpy as np +from warnings import warn + +from nibabel.affines import apply_affine + +from .compact_list import CompactList + + +class UsageWarning(Warning): + pass + + +class TractogramItem(object): + """ Class containing information about one streamline. + + ``TractogramItem`` objects have three main properties: `streamline`, + `data_for_streamline`, and `data_for_points`. + + Parameters + ---------- + streamline : ndarray of shape (N, 3) + Points of this streamline represented as an ndarray of shape (N, 3) + where N is the number of points. + + data_for_streamline : dict + Dictionary whose values are ndarrays of shape (P,), i.e. data associated with the streamline as a whole. + + data_for_points : dict + Dictionary whose values are ndarrays of shape (N, M), i.e. data associated with each of the N points. + """ + def __init__(self, streamline, data_for_streamline, data_for_points): + self.streamline = np.asarray(streamline) + self.data_for_streamline = data_for_streamline + self.data_for_points = data_for_points + + def __iter__(self): + return iter(self.streamline) + + def __len__(self): + return len(self.streamline) + + +class Tractogram(object): + """ Class containing information about streamlines. + + Tractogram objects have three main properties: ``streamlines``, + ``data_per_streamline`` and ``data_per_point``. + + Parameters + ---------- + streamlines : list of ndarray of shape (Nt, 3) + Sequence of T streamlines. One streamline is an ndarray of shape + (Nt, 3) where Nt is the number of points of streamline t. + + data_per_streamline : dictionary of list of ndarray of shape (P,) + Sequence of T ndarrays of shape (P,) where T is the number of + streamlines defined by ``streamlines``, P is the number of properties + associated to each streamline. + + data_per_point : dictionary of list of ndarray of shape (Nt, M) + Sequence of T ndarrays of shape (Nt, M) where T is the number of + streamlines defined by ``streamlines``, Nt is the number of points + for a particular streamline t and M is the number of scalars + associated to each point (excluding the three coordinates).
+ + """ + def __init__(self, streamlines=None, + data_per_streamline=None, + data_per_point=None): + + self.streamlines = streamlines + self.data_per_streamline = data_per_streamline + self.data_per_point = data_per_point + + @property + def streamlines(self): + return self._streamlines + + @streamlines.setter + def streamlines(self, value): + self._streamlines = CompactList(value) + + @property + def data_per_streamline(self): + return self._data_per_streamline + + @data_per_streamline.setter + def data_per_streamline(self, value): + if value is None: + value = {} + + self._data_per_streamline = {} + for k, v in value.items(): + self._data_per_streamline[k] = np.asarray(v) + + @property + def data_per_point(self): + return self._data_per_point + + @data_per_point.setter + def data_per_point(self, value): + if value is None: + value = {} + + self._data_per_point = {} + for k, v in value.items(): + self._data_per_point[k] = CompactList(v) + + def __iter__(self): + for i in range(len(self.streamlines)): + yield self[i] + + def __getitem__(self, idx): + pts = self.streamlines[idx] + + new_data_per_streamline = {} + for key in self.data_per_streamline: + new_data_per_streamline[key] = self.data_per_streamline[key][idx] + + new_data_per_point = {} + for key in self.data_per_point: + new_data_per_point[key] = self.data_per_point[key][idx] + + if type(idx) is slice: + return Tractogram(pts, new_data_per_streamline, new_data_per_point) + + return TractogramItem(pts, new_data_per_streamline, new_data_per_point) + + def __len__(self): + return len(self.streamlines) + + def copy(self): + """ Returns a copy of this `Tractogram` object. """ + new_data_per_streamline = {} + for key in self.data_per_streamline: + new_data_per_streamline[key] = self.data_per_streamline[key].copy() + + new_data_per_point = {} + for key in self.data_per_point: + new_data_per_point[key] = self.data_per_point[key].copy() + + tractogram = Tractogram(self.streamlines.copy(), + new_data_per_streamline, + new_data_per_point) + return tractogram + + def apply_affine(self, affine): + """ Applies an affine transformation on the points of each streamline. + + This is performed in-place. + + Parameters + ---------- + affine : 2D array (4,4) + Transformation that will be applied on each streamline. + """ + if len(self.streamlines) == 0: + return + + BUFFER_SIZE = 10000000 # About 128 Mb since pts shape is 3. + for i in range(0, len(self.streamlines._data), BUFFER_SIZE): + pts = self.streamlines._data[i:i+BUFFER_SIZE] + self.streamlines._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts) + + +import collections +class LazyTractogram(Tractogram): + ''' Class containing information about streamlines. + + Tractogram objects have four main properties: ``header``, ``streamlines``, + ``scalars`` and ``properties``. Tractogram objects are iterable and + produce tuple of ``streamlines``, ``scalars`` and ``properties`` for each + streamline. + + Parameters + ---------- + streamlines_func : coroutine ouputting (Nt,3) array-like (optional) + Function yielding streamlines. One streamline is + an ndarray of shape (Nt,3) where Nt is the number of points of + streamline t. + + scalars_func : coroutine ouputting (Nt,M) array-like (optional) + Function yielding scalars for a particular streamline t. The scalars + are represented as an ndarray of shape (Nt,M) where Nt is the number + of points of that streamline t and M is the number of scalars + associated to each point (excluding the three coordinates). 
+ + properties_func : coroutine outputting (P,) array-like (optional) + Function yielding properties for a particular streamline t. The + properties are represented as an ndarray of shape (P,) where P is + the number of properties associated to each streamline. + + getitem_func : function `idx -> 3-tuples` (optional) + Function returning a subset of the tractogram given an index or a + slice (i.e. the __getitem__ function to use). + + Notes + ----- + If provided, ``scalars`` and ``properties`` must yield the same number of + values as ``streamlines``. + ''' + + class LazyDict(collections.MutableMapping): + """ Internal dictionary with lazy evaluations. """ + + def __init__(self, *args, **kwargs): + self.store = dict() + self.update(dict(*args, **kwargs)) # Use update to set keys. + + def __getitem__(self, key): + return self.store[key]() + + def __setitem__(self, key, value): + if value is not None and not callable(value): + raise TypeError("`value` must be a coroutine or None.") + + self.store[key] = value + + def __delitem__(self, key): + del self.store[key] + + def __iter__(self): + return iter(self.store) + + def __len__(self): + return len(self.store) + + def __init__(self, streamlines=None, data_per_streamline=None, + data_per_point=None): + super(LazyTractogram, self).__init__(streamlines, data_per_streamline, + data_per_point) + self._nb_streamlines = None + self._data = None + self._getitem = None + self._affine_to_apply = np.eye(4) + + @classmethod + def create_from(cls, data_func): + ''' Creates a `LazyTractogram` from a coroutine yielding + `TractogramItem` objects. + + Parameters + ---------- + data_func : coroutine yielding `TractogramItem` objects + A function that, whenever it is called, starts yielding + `TractogramItem` objects that should be part of this + LazyTractogram. + + ''' + if not callable(data_func): + raise TypeError("`data_func` must be a coroutine.") + + lazy_tractogram = cls() + lazy_tractogram._data = data_func + + # Set data_per_streamline using data_func + def _gen(key): + return lambda: (t.data_for_streamline[key] for t in data_func()) + + data_per_streamline_keys = next(data_func()).data_for_streamline.keys() + for k in data_per_streamline_keys: + lazy_tractogram._data_per_streamline[k] = _gen(k) + + # Set data_per_point using data_func + def _gen(key): + return lambda: (t.data_for_points[key] for t in data_func()) + + data_per_point_keys = next(data_func()).data_for_points.keys() + for k in data_per_point_keys: + lazy_tractogram._data_per_point[k] = _gen(k) + + return lazy_tractogram + + @property + def streamlines(self): + streamlines_gen = iter([]) + if self._streamlines is not None: + streamlines_gen = self._streamlines() + elif self._data is not None: + streamlines_gen = (t.streamline for t in self._data()) + + # Check if we need to apply an affine.
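+ # If one is pending, wrap the generator so each streamline is transformed lazily as it is yielded.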
+ if not np.all(self._affine_to_apply == np.eye(4)): + def _apply_affine(): + for s in streamlines_gen: + yield apply_affine(self._affine_to_apply, s) + + streamlines_gen = _apply_affine() + + return streamlines_gen + + @streamlines.setter + def streamlines(self, value): + if value is not None and not callable(value): + raise TypeError("`streamlines` must be a coroutine.") + + self._streamlines = value + + @property + def data_per_streamline(self): + return self._data_per_streamline + + @data_per_streamline.setter + def data_per_streamline(self, value): + if value is None: + value = {} + + self._data_per_streamline = LazyTractogram.LazyDict(value) + + @property + def data_per_point(self): + return self._data_per_point + + @data_per_point.setter + def data_per_point(self, value): + if value is None: + value = {} + + self._data_per_point = LazyTractogram.LazyDict(value) + + @property + def data(self): + if self._data is not None: + return self._data() + + def _gen_data(): + data_per_streamline_generators = {} + for k, v in self.data_per_streamline.items(): + data_per_streamline_generators[k] = iter(v) + + data_per_point_generators = {} + for k, v in self.data_per_point.items(): + data_per_point_generators[k] = iter(v) + + for s in self.streamlines: + data_for_streamline = {} + for k, v in data_per_streamline_generators.items(): + data_for_streamline[k] = next(v) + + data_for_points = {} + for k, v in data_per_point_generators.items(): + data_for_points[k] = next(v) + + yield TractogramItem(s, data_for_streamline, data_for_points) + + return _gen_data() + + def __getitem__(self, idx): + if self._getitem is None: + raise AttributeError('`LazyTractogram` does not support indexing.') + + return self._getitem(idx) + + def __iter__(self): + i = 0 + for i, tractogram_item in enumerate(self.data, start=1): + yield tractogram_item + + # Keep how many streamlines there are in this tractogram. + self._nb_streamlines = i + + def __len__(self): + # Check if we know how many streamlines there are. + if self._nb_streamlines is None: + warn("Number of streamlines will be determined manually by looping" + " through the streamlines. If you know the actual number of" + " streamlines, you might want to set it beforehand via" + " `self.header.nb_streamlines`." + " Note this will consume any generators used to create this" + " `LazyTractogram` object.", UsageWarning) + # Count the number of streamlines. + self._nb_streamlines = sum(1 for _ in self.streamlines) + + return self._nb_streamlines + + def copy(self): + """ Returns a copy of this `LazyTractogram` object. """ + tractogram = LazyTractogram(self._streamlines, + self._data_per_streamline, + self._data_per_point) + tractogram._nb_streamlines = self._nb_streamlines + tractogram._data = self._data + return tractogram + + def apply_affine(self, affine): + """ Applies an affine transformation on the streamlines. + + Parameters + ---------- + affine : 2D array (4,4) + Transformation that will be applied on each streamline.
+ """ + self._affine_to_apply = np.dot(affine, self._affine_to_apply) diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py new file mode 100644 index 0000000000..2ec5b69a68 --- /dev/null +++ b/nibabel/streamlines/tractogram_file.py @@ -0,0 +1,103 @@ +from abc import ABCMeta, abstractmethod, abstractproperty + +from .header import TractogramHeader + + +class abstractclassmethod(classmethod): + __isabstractmethod__ = True + + def __init__(self, callable): + callable.__isabstractmethod__ = True + super(abstractclassmethod, self).__init__(callable) + + +class TractogramFile(object): + ''' Convenience class to encapsulate tractogram file format. ''' + __metaclass__ = ABCMeta + + def __init__(self, tractogram, header=None): + self._tractogram = tractogram + self._header = TractogramHeader() if header is None else header + + @property + def tractogram(self): + return self._tractogram + + @property + def streamlines(self): + return self.tractogram.streamlines + + @property + def header(self): + return self._header + + def get_tractogram(self): + return self.tractogram + + def get_streamlines(self): + return self.streamlines + + def get_header(self): + return self.header + + @classmethod + def get_magic_number(cls): + ''' Returns streamlines file's magic number. ''' + raise NotImplementedError() + + @classmethod + def support_data_per_point(cls): + ''' Tells if this tractogram format supports saving data per point. ''' + raise NotImplementedError() + + @classmethod + def support_data_per_streamline(cls): + ''' Tells if this tractogram format supports saving data per streamline. ''' + raise NotImplementedError() + + @classmethod + def is_correct_format(cls, fileobj): + ''' Checks if the file has the right streamlines file format. + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header). + Returns + ------- + is_correct_format : boolean + Returns True if `fileobj` is in the right streamlines file format. + ''' + raise NotImplementedError() + + @abstractclassmethod + def load(cls, fileobj, lazy_load=True): + ''' Loads streamlines from a file-like object. + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header). + lazy_load : boolean (optional) + Load streamlines in a lazy manner i.e. they will not be kept + in memory. For postprocessing speed, turn off this option. + Returns + ------- + tractogram_file : ``TractogramFile`` object + Returns an object containing tractogram data and header + information. + ''' + raise NotImplementedError() + + @abstractmethod + def save(self, fileobj): + ''' Saves streamlines to a file-like object. + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + opened and ready to write. 
+ ''' + raise NotImplementedError() diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index b93a65f5d9..6d1fc66397 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -3,9 +3,10 @@ # Documentation available here: # http://www.trackvis.org/docs/?subsect=fileformat -import struct import os +import struct import warnings +import itertools import numpy as np import nibabel as nib @@ -13,12 +14,13 @@ from nibabel.openers import Opener from nibabel.volumeutils import (native_code, swapped_code) -from nibabel.streamlines.base_format import TractogramFile -from nibabel.streamlines.base_format import DataError, HeaderError, HeaderWarning -from nibabel.streamlines.base_format import Tractogram, LazyTractogram -from nibabel.streamlines.header import Field +from .compact_list import CompactList +from .tractogram_file import TractogramFile +from .base_format import DataError, HeaderError, HeaderWarning +from .tractogram import Tractogram, LazyTractogram +from .header import Field -from nibabel.streamlines.utils import get_affine_from_reference +from .utils import get_affine_from_reference # Definition of trackvis header structure. # See http://www.trackvis.org/docs/?subsect=fileformat @@ -298,9 +300,6 @@ def write(self, tractogram): self.file.write(self.header.tostring()) -import itertools -from nibabel.streamlines.base_format import CompactList - def create_compactlist_from_generator(gen): BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index 7fbfab5106..7bbbe1ef8d 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -33,22 +33,3 @@ def pop(iterable): "Returns the next item from the iterable else None" value = list(itertools.islice(iterable, 1)) return value[0] if len(value) > 0 else None - - -def save_compact_list(filename, compact_list): - """ Saves a `CompactList` object to a .npz file. """ - np.savez(filename, - data=compact_list._data, - offsets=compact_list._offsets, - lengths=compact_list._lengths) - - -def load_compact_list(filename): - """ Loads a `CompactList` object from a .npz file. """ - from .base_format import CompactList - content = np.load(filename) - compact_list = CompactList() - compact_list._data = content["data"] - compact_list._offsets = content["offsets"].tolist() - compact_list._lengths = content["lengths"].tolist() - return compact_list From 5b65636ebd82090b4d5f91f4a26fb997d41c3635 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 17 Nov 2015 17:33:55 -0500 Subject: [PATCH 28/54] Save scalars and properties name when using the TRK file format --- nibabel/streamlines/tractogram.py | 2 +- nibabel/streamlines/trk.py | 119 ++++++++++++++++++++++-------- 2 files changed, 89 insertions(+), 32 deletions(-) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index c7e0593065..f7907ca888 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -1,4 +1,5 @@ import numpy as np +import collections from warnings import warn from nibabel.affines import apply_affine @@ -160,7 +161,6 @@ def apply_affine(self, affine): self.streamlines._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts) -import collections class LazyTractogram(Tractogram): ''' Class containing information about streamlines. 
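[Reviewer note] Before the TRK changes below, a minimal sketch of how the lazy API refactored above is meant to be driven; `LazyTractogram` and the `data_per_point` keyword come from this series, while the toy arrays are illustrative only:

    import numpy as np
    from nibabel.streamlines.tractogram import LazyTractogram

    points = [np.arange(n * 3, dtype="f4").reshape((n, 3)) for n in (1, 2, 5)]
    colors = [np.tile(c, (len(p), 1)) for c, p in zip(np.eye(3, dtype="f4"), points)]

    # Pass callables (not bare generators) so the data can be re-iterated.
    tractogram = LazyTractogram(lambda: (p for p in points),
                                data_per_point={"colors": lambda: (c for c in colors)})

    for item in tractogram:  # Each item is a TractogramItem.
        assert item.streamline.shape == item.data_for_points["colors"].shape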
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 6d1fc66397..b8c325aca0 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -17,11 +17,14 @@ from .compact_list import CompactList from .tractogram_file import TractogramFile from .base_format import DataError, HeaderError, HeaderWarning -from .tractogram import Tractogram, LazyTractogram +from .tractogram import TractogramItem, Tractogram, LazyTractogram from .header import Field from .utils import get_affine_from_reference +MAX_NB_NAMED_SCALARS_PER_POINT = 10 +MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE = 10 + # Definition of trackvis header structure. # See http://www.trackvis.org/docs/?subsect=fileformat # See http://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html @@ -30,9 +33,9 @@ (Field.VOXEL_SIZES, 'f4', 3), (Field.ORIGIN, 'f4', 3), (Field.NB_SCALARS_PER_POINT, 'h'), - ('scalar_name', 'S20', 10), + ('scalar_name', 'S20', MAX_NB_NAMED_SCALARS_PER_POINT), (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), - ('property_name', 'S20', 10), + ('property_name', 'S20', MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE), ('reserved', 'S508'), (Field.VOXEL_ORDER, 'S4'), ('pad2', 'S4'), @@ -58,9 +61,9 @@ (Field.VOXEL_SIZES, 'f4', 3), (Field.ORIGIN, 'f4', 3), (Field.NB_SCALARS_PER_POINT, 'h'), - ('scalar_name', 'S20', 10), + ('scalar_name', 'S20', MAX_NB_NAMED_SCALARS_PER_POINT), (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), - ('property_name', 'S20', 10), + ('property_name', 'S20', MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE), (Field.VOXEL_TO_RASMM, 'f4', (4, 4)), # new field for version 2 ('reserved', 'S444'), (Field.VOXEL_ORDER, 'S4'), @@ -264,9 +267,9 @@ def write(self, tractogram): raise DataError("Missing scalars for some points!") points = np.asarray(t.streamline, dtype=f4_dtype) - keys = sorted(t.data_for_points.keys()) + keys = sorted(t.data_for_points.keys())[:MAX_NB_NAMED_SCALARS_PER_POINT] scalars = np.asarray([t.data_for_points[k] for k in keys], dtype=f4_dtype).reshape((len(points), -1)) - keys = sorted(t.data_for_streamline.keys()) + keys = sorted(t.data_for_streamline.keys())[:MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE] properties = np.asarray([t.data_for_streamline[k] for k in keys], dtype=f4_dtype).flatten() data = struct.pack(i4_dtype.str[:-1], len(points)) @@ -524,36 +527,59 @@ def _apply_transform(trk_reader): trk_reader yield pts, scals, props - data = lambda: _apply_transform(trk_reader) - tractogram = LazyTractogram.create_from_data(data) + def _read(): + for pts, scals, props in trk_reader: + # TODO + data_for_streamline = {} + data_for_points = {} + yield TractogramItem(pts, data_for_streamline, data_for_points) - # Overwrite scalars and properties if there is none - if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: - tractogram.scalars = lambda: [] - if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: - tractogram.properties = lambda: [] + tractogram = LazyTractogram.create_from(_read) else: streamlines, scalars, properties = create_compactlist_from_generator(trk_reader) tractogram = Tractogram(streamlines) if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0: - if len(trk_reader.header['scalar_name'][0]) > 0: - for i in range(trk_reader.header[Field.NB_SCALARS_PER_POINT]): - clist = CompactList() - clist._data = scalars._data[:, i] - clist._offsets = scalars._offsets - clist._lengths = scalars._lengths - tractogram.data_per_point[trk_reader.header['scalar_name'][i]] = clist - else: - tractogram.data_per_point['scalars'] = scalars + cpt = 0 + for scalar_name in 
trk_reader.header['scalar_name']: + if len(scalar_name) == 0: + continue + + nb_scalars = np.fromstring(scalar_name[-1], np.int8) + + clist = CompactList() + clist._data = scalars._data[:, cpt:cpt+nb_scalars] + clist._offsets = scalars._offsets + clist._lengths = scalars._lengths + + scalar_name = scalar_name.split('\x00')[0] + tractogram.data_per_point[scalar_name] = clist + cpt += nb_scalars + + if cpt < trk_reader.header[Field.NB_SCALARS_PER_POINT]: + #tractogram.data_per_point['scalars'] = scalars + clist = CompactList() + clist._data = scalars._data[:, cpt:] + clist._offsets = scalars._offsets + clist._lengths = scalars._lengths + tractogram.data_per_point['scalars'] = clist if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: - if len(trk_reader.header['property_name'][0]) > 0: - for i in range(trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]): - tractogram.data_per_streamline[trk_reader.header['property_name'][i]] = properties[:, i] - else: - tractogram.data_per_streamline['properties'] = properties + cpt = 0 + for property_name in trk_reader.header['property_name']: + if len(property_name) == 0: + continue + + nb_properties = np.fromstring(property_name[-1], np.int8) + property_name = property_name.split('\x00')[0] + tractogram.data_per_streamline[property_name] = properties[:, cpt:cpt+nb_properties] + cpt += nb_properties + + if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: + #tractogram.data_per_streamline['properties'] = properties + nb_properties = np.fromstring(property_name[-1], np.int8) + tractogram.data_per_streamline['properties'] = properties[:, cpt:] # Bring tractogram to RAS+ and mm space tractogram.apply_affine(affine) @@ -578,9 +604,40 @@ def save(self, fileobj): pointing to TRK file (and ready to read from the beginning of the TRK header data). ''' - # Update header using the tractogram. - self.header.nb_scalars_per_point = sum(map(lambda e: len(e[0]), self.tractogram.data_per_point.values())) - self.header.nb_properties_per_streamline = sum(map(lambda e: len(e[0]), self.tractogram.data_per_streamline.values())) + # Compute how many properties per streamline the tractogram has. + self.header.nb_properties_per_streamline = 0 + self.header.extra['property_name'] = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, dtype='S20') + data_for_streamline = self.tractogram[0].data_for_streamline + for i, k in enumerate(sorted(data_for_streamline.keys())): + if i >= MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE: + warnings.warn(("Can only store {0} named properties: '{1}' will be omitted.".format(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, k)), HeaderWarning) + + if len(k) > 19: + warnings.warn(("Property name '{0}' has been truncated to {1}.".format(k, k[:19])), HeaderWarning) + + v = data_for_streamline[k] + self.header.nb_properties_per_streamline += v.shape[0] + + property_name = k[:19].ljust(19, '\x00') + np.array(v.shape[0], dtype=np.int8).tostring() + self.header.extra['property_name'][i] = property_name + + # Compute how many scalars per point the tractogram has.
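+ # Same 20-byte layout as for the property names above: 19 NUL-padded name bytes, then one byte giving the number of values stored under that name.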
+ self.header.nb_scalars_per_point = 0 + self.header.extra['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20') + data_for_points = self.tractogram[0].data_for_points + for i, k in enumerate(sorted(data_for_points.keys())): + if i >= MAX_NB_NAMED_SCALARS_PER_POINT: + warnings.warn(("Can only store {0} named scalars: '{1}' will be omitted.".format(MAX_NB_NAMED_SCALARS_PER_POINT, k)), HeaderWarning) + + if len(k) > 19: + warnings.warn(("Scalar name '{0}' has been truncated to {1}.".format(k, k[:19])), HeaderWarning) + + v = data_for_points[k] + self.header.nb_scalars_per_point += v.shape[1] + + scalar_name = k[:19].ljust(19, '\x00') + np.array(v.shape[1], dtype=np.int8).tostring() + self.header.extra['scalar_name'][i] = scalar_name + trk_writer = TrkWriter(fileobj, self.header) trk_writer.write(self.tractogram) From 6f7690edf7047230548759ab615676c669626091 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 17 Nov 2015 22:47:48 -0500 Subject: [PATCH 29/54] BF: Extend on empty CompactList is now allowed --- nibabel/streamlines/compact_list.py | 4 ++++ nibabel/streamlines/tests/test_compact_list.py | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index 87671bb5c2..8ead3ed50d 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -101,6 +101,10 @@ def extend(self, elements): shape except for the first dimension. """ + if self._data is None: + elem = np.asarray(elements[0]) + self._data = np.zeros((0, elem.shape[1]), dtype=elem.dtype) + if isinstance(elements, CompactList): self._data = np.concatenate([self._data, elements._data], axis=0) lengths = elements._lengths diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py index 2211d9d303..90783a29b0 100644 --- a/nibabel/streamlines/tests/test_compact_list.py +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -173,6 +173,14 @@ def test_compactlist_extend(self): assert_equal(clist._lengths[-len(new_data):], lengths) assert_array_equal(clist._data[-sum(lengths):], new_data._data) + + # Test extending an empty CompactList + clist = CompactList() + clist.extend(new_data) + assert_equal(len(clist), len(new_data)) + assert_array_equal(clist._offsets, new_data._offsets) + assert_array_equal(clist._lengths, new_data._lengths) + assert_array_equal(clist._data, new_data._data) + def test_compactlist_getitem(self): # Get one item for i, e in enumerate(self.clist): From 999bba47568bb93bee5d1daf35f7c5b8122612ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 18 Nov 2015 10:30:05 -0500 Subject: [PATCH 30/54] BF: Support creating TractogramHeader from dict --- nibabel/streamlines/header.py | 22 +++++++++++++++++++++- nibabel/streamlines/tests/test_trk.py | 13 +++++++++++-- nibabel/streamlines/tractogram_file.py | 3 ++- nibabel/streamlines/trk.py | 3 +-- 4 files changed, 35 insertions(+), 6 deletions(-) diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py index 2298e395b2..4b7ea30605 100644 --- a/nibabel/streamlines/header.py +++ b/nibabel/streamlines/header.py @@ -25,13 +25,33 @@ class Field: class TractogramHeader(object): - def __init__(self): + def __init__(self, hdr=None): self._nb_streamlines = None self._nb_scalars_per_point = None self._nb_properties_per_streamline = None self._to_world_space = np.eye(4) self.extra = OrderedDict() + + if type(hdr) is dict: + if
+            if Field.NB_STREAMLINES in hdr:
+                self.nb_streamlines = hdr[Field.NB_STREAMLINES]
+
+            if Field.NB_SCALARS_PER_POINT in hdr:
+                self.nb_scalars_per_point = hdr[Field.NB_SCALARS_PER_POINT]
+
+            if Field.NB_PROPERTIES_PER_STREAMLINE in hdr:
+                self.nb_properties_per_streamline = hdr[Field.NB_PROPERTIES_PER_STREAMLINE]
+
+            if Field.VOXEL_TO_RASMM in hdr:
+                self.to_world_space = hdr[Field.VOXEL_TO_RASMM]
+
+        elif type(hdr) is TractogramHeader:
+            self.nb_streamlines = hdr.nb_streamlines
+            self.nb_scalars_per_point = hdr.nb_scalars_per_point
+            self.nb_properties_per_streamline = hdr.nb_properties_per_streamline
+            self.to_world_space = hdr.to_world_space
+            self.extra = copy.deepcopy(hdr.extra)
+
     @property
     def to_world_space(self):
         return self._to_world_space
diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py
index 1b1c2ab7a6..3322140097 100644
--- a/nibabel/streamlines/tests/test_trk.py
+++ b/nibabel/streamlines/tests/test_trk.py
@@ -9,8 +9,8 @@
 from nose.tools import assert_equal, assert_raises, assert_true

 from .. import base_format
-from ..base_format import Tractogram, LazyTractogram
-from ..base_format import DataError, HeaderError, HeaderWarning, UsageWarning
+from ..tractogram import Tractogram, LazyTractogram
+from ..base_format import DataError, HeaderError, HeaderWarning#, UsageWarning

 #from .. import trk
 from ..trk import TrkFile
@@ -254,6 +254,15 @@ def test_write_erroneous_file(self):
         trk = TrkFile(tractogram, ref=self.affine)
         assert_raises(IndexError, trk.save, BytesIO())

+    def test_load_write_simple_file(self):
+        trk = TrkFile.load(self.simple_trk_filename, lazy_load=False)
+        trk_file = BytesIO()
+        trk.save(trk_file)
+
+        # trk = TrkFile.load(self.simple_trk_filename, lazy_load=True)
+        # check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], [])
+
+
     # def test_write_file_lazy_tractogram(self):
     #     streamlines = lambda: (point for point in self.streamlines)
     #     scalars = lambda: (scalar for scalar in self.colors)
diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py
index 2ec5b69a68..d3f54c2a84 100644
--- a/nibabel/streamlines/tractogram_file.py
+++ b/nibabel/streamlines/tractogram_file.py
@@ -17,7 +17,8 @@ class TractogramFile(object):

     def __init__(self, tractogram, header=None):
         self._tractogram = tractogram
-        self._header = TractogramHeader() if header is None else header
+        #self._header = TractogramHeader() if header is None else header
+        self._header = TractogramHeader(header)

     @property
     def tractogram(self):
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index b8c325aca0..92b5466ef0 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -578,7 +578,6 @@ def _read():

         if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]:
             #tractogram.data_per_streamline['properties'] = properties
-            nb_properties = np.fromstring(property_name[-1], np.int8)
             tractogram.data_per_streamline['properties'] = properties[:, cpt:]

         # Bring tractogram to RAS+ and mm space
@@ -592,7 +591,7 @@ def _read():
         #if tractogram.header.nb_properties_per_streamline != trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]:
         #    raise HeaderError("'nb_properties_per_streamline' does not match.")

-        return cls(tractogram, ref=affine, header=trk_reader.header)
+        return cls(tractogram, header=trk_reader.header, ref=affine)

     def save(self, fileobj):
         ''' Saves tractogram to a file-like object using TRK format.
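With patch 30 applied, a TractogramHeader can be built straight from the field
dict a reader produces, or copied from another header. A small usage sketch of
that constructor under the draft API above; the field values and the import
path are assumptions for illustration:

    import numpy as np
    from nibabel.streamlines.header import Field, TractogramHeader

    hdr_dict = {Field.NB_STREAMLINES: 3,
                Field.NB_SCALARS_PER_POINT: 1,
                Field.NB_PROPERTIES_PER_STREAMLINE: 2,
                Field.VOXEL_TO_RASMM: np.eye(4)}

    hdr = TractogramHeader(hdr_dict)   # built from a dict of Field entries
    hdr2 = TractogramHeader(hdr)       # built from another TractogramHeader
    assert hdr2.nb_streamlines == 3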
From 396e02e7cd35f43ba7483a835dcb55f64b1967a9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Wed, 18 Nov 2015 15:24:03 -0500
Subject: [PATCH 31/54] BF: Fixed creating TractogramHeader from another
 TractogramHeader when its nb_streamlines was None

---
 nibabel/streamlines/header.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py
index 4b7ea30605..3fac6952bd 100644
--- a/nibabel/streamlines/header.py
+++ b/nibabel/streamlines/header.py
@@ -46,10 +46,10 @@ def __init__(self, hdr=None):

         elif type(hdr) is TractogramHeader:
-            self.nb_streamlines = hdr.nb_streamlines
-            self.nb_scalars_per_point = hdr.nb_scalars_per_point
-            self.nb_properties_per_streamline = hdr.nb_properties_per_streamline
-            self.to_world_space = hdr.to_world_space
+            self._nb_streamlines = hdr._nb_streamlines
+            self._nb_scalars_per_point = hdr._nb_scalars_per_point
+            self._nb_properties_per_streamline = hdr._nb_properties_per_streamline
+            self._to_world_space = hdr._to_world_space
             self.extra = copy.deepcopy(hdr.extra)

     @property

From 84112dba260a70be8fe45163b42271698706cc33 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Thu, 19 Nov 2015 01:01:01 -0500
Subject: [PATCH 32/54] BF: Not all property and scalar names were saved in
 the TRK header

---
 nibabel/streamlines/tests/data/complex.trk    | Bin 1228 -> 1228 bytes
 nibabel/streamlines/tests/test_streamlines.py |   6 +++---
 nibabel/streamlines/tests/test_trk.py         |   3 +++
 nibabel/streamlines/trk.py                    |  11 ++++++++++-
 4 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/nibabel/streamlines/tests/data/complex.trk b/nibabel/streamlines/tests/data/complex.trk
index 9bfbe5ea60917d6f98920b3a08e460be5fb6732d..e6155154c43fa34577517fefce8f1ac14b9a4ce7 100644
GIT binary patch
delta 64
zcmX@Zd4_X>8hdeaVoqXF@kEJ(VoVGLMfnA(MJ1W3#SAdOG|_O|+C=Nk69YsiGc#@nQRf(KfLLU+0Mio2i2-bzw=hdE0suGt4gvrG

diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py
index 5dabad1815..e27e7cc42a 100644
--- a/nibabel/streamlines/tests/test_streamlines.py
+++ b/nibabel/streamlines/tests/test_streamlines.py
@@ -12,9 +12,9 @@
 from nibabel.testing import assert_arrays_equal, isiterable
 from nose.tools import assert_equal, assert_raises, assert_true, assert_false

-from ..base_format import Tractogram, LazyTractogram, TractogramFile
-from ..base_format import HeaderError, UsageWarning
-from ..header import Field
+from ..tractogram import Tractogram, LazyTractogram
+from ..tractogram_file import TractogramFile
+from ..tractogram import UsageWarning
 from .. import trk

 DATA_PATH = pjoin(os.path.dirname(__file__), 'data')
diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py
index 3322140097..ed7317db5a 100644
--- a/nibabel/streamlines/tests/test_trk.py
+++ b/nibabel/streamlines/tests/test_trk.py
@@ -206,6 +206,9 @@ def test_write_complex_file(self):
         assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram)

         trk_file.seek(0, os.SEEK_SET)
+        #raw = trk_file.read()
+        #raw = raw[:38] + "\x00"*200 + raw[38+200:]  # Overwrite the scalar name section
+        #raw = raw[:240] + "\x00"*200 + raw[240+200:]  # Overwrite the property name section
         assert_equal(open(self.complex_trk_filename, 'rb').read(), trk_file.read())

     def test_write_erroneous_file(self):
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index 92b5466ef0..6c97c1b1fd 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -218,11 +218,20 @@ def create_empty_header(cls):
     def __init__(self, fileobj, header):
         self.header = self.create_empty_header()

-        # Override hdr's fields by those contain in `header`.
+        # Override hdr's fields by those contained in `header`.
         for k, v in header.extra.items():
             if k in header_2_dtype.fields.keys():
                 self.header[k] = v

+        # TODO: Fix that ugly patch.
+        # Because the assignment operator on ndarray of string only copies
+        # the first entry, we have to do it explicitly!
+        if "property_name" in header.extra:
+            self.header["property_name"][:] = header.extra["property_name"][:]
+
+        if "scalar_name" in header.extra:
+            self.header["scalar_name"][:] = header.extra["scalar_name"][:]
+
         self.header[Field.NB_STREAMLINES] = 0
         if header.nb_streamlines is not None:
             self.header[Field.NB_STREAMLINES] = header.nb_streamlines

From bfdeb6bd492a1cdb20e425d8ece468db5d41886a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Thu, 19 Nov 2015 14:51:13 -0500
Subject: [PATCH 33/54] ENH: CompactList supports advanced indexing with
 ndarray of data type bool

---
 nibabel/streamlines/compact_list.py            | 47 +++++++++++--------
 .../streamlines/tests/test_compact_list.py     | 13 +++++
 2 files changed, 41 insertions(+), 19 deletions(-)

diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py
index 8ead3ed50d..fb98536bab 100644
--- a/nibabel/streamlines/compact_list.py
+++ b/nibabel/streamlines/compact_list.py
@@ -154,18 +154,27 @@ def __getitem__(self, idx):
             return self._data[start:start+self._lengths[idx]]

         elif type(idx) is slice:
-            compact_list = CompactList()
-            compact_list._data = self._data
-            compact_list._offsets = self._offsets[idx]
-            compact_list._lengths = self._lengths[idx]
-            return compact_list
+            clist = CompactList()
+            clist._data = self._data
+            clist._offsets = self._offsets[idx]
+            clist._lengths = self._lengths[idx]
+            return clist

         elif type(idx) is list:
-            compact_list = CompactList()
-            compact_list._data = self._data
-            compact_list._offsets = [self._offsets[i] for i in idx]
-            compact_list._lengths = [self._lengths[i] for i in idx]
-            return compact_list
+            clist = CompactList()
+            clist._data = self._data
+            clist._offsets = [self._offsets[i] for i in idx]
+            clist._lengths = [self._lengths[i] for i in idx]
+            return clist
+
+        elif isinstance(idx, np.ndarray) and idx.dtype == np.bool:
+            clist = CompactList()
+            clist._data = self._data
+            clist._offsets = [self._offsets[i]
+                              for i, take_it in enumerate(idx) if take_it]
+            clist._lengths = [self._lengths[i]
+                              for i, take_it in enumerate(idx) if take_it]
+            return clist

        raise TypeError("Index must be an int, a slice, a list or a boolean ndarray! Not " + str(type(idx)))
slice! Not " + str(type(idx))) @@ -184,19 +193,19 @@ def __repr__(self): return repr(list(self)) -def save_compact_list(filename, compact_list): +def save_compact_list(filename, clist): """ Saves a `CompactList` object to a .npz file. """ np.savez(filename, - data=compact_list._data, - offsets=compact_list._offsets, - lengths=compact_list._lengths) + data=clist._data, + offsets=clist._offsets, + lengths=clist._lengths) def load_compact_list(filename): """ Loads a `CompactList` object from a .npz file. """ content = np.load(filename) - compact_list = CompactList() - compact_list._data = content["data"] - compact_list._offsets = content["offsets"].tolist() - compact_list._lengths = content["lengths"].tolist() - return compact_list + clist = CompactList() + clist._data = content["data"] + clist._offsets = content["offsets"].tolist() + clist._lengths = content["lengths"].tolist() + return clist diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py index 90783a29b0..71dac112b7 100644 --- a/nibabel/streamlines/tests/test_compact_list.py +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -206,6 +206,19 @@ def test_compactlist_getitem(self): for i, e in enumerate(clist_view): assert_array_equal(e, self.clist[i*2]) + # Use advance indexing with ndarray of data type bool. + idx = np.array([False, True, True, False, True]) + clist_view = self.clist[idx] + assert_true(clist_view is not self.clist) + assert_true(clist_view._data is self.clist._data) + assert_array_equal(clist_view._offsets, + np.asarray(self.clist._offsets)[idx]) + assert_array_equal(clist_view._lengths, + np.asarray(self.clist._lengths)[idx]) + assert_array_equal(clist_view[0], self.clist[1]) + assert_array_equal(clist_view[1], self.clist[2]) + assert_array_equal(clist_view[2], self.clist[4]) + def test_save_and_load_compact_list(): From 5857c43a35c68ebe1c2844cb6c8f484316007ca2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Thu, 19 Nov 2015 14:59:59 -0500 Subject: [PATCH 34/54] ENH: Tractogram support advance indexing with ndarray of data type bool or integer --- nibabel/streamlines/tractogram.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index f7907ca888..c4ed299dc5 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -111,35 +111,35 @@ def __iter__(self): def __getitem__(self, idx): pts = self.streamlines[idx] - new_data_per_streamline = {} + data_per_streamline = {} for key in self.data_per_streamline: - new_data_per_streamline[key] = self.data_per_streamline[key][idx] + data_per_streamline[key] = self.data_per_streamline[key][idx] - new_data_per_point = {} + data_per_point = {} for key in self.data_per_point: - new_data_per_point[key] = self.data_per_point[key][idx] + data_per_point[key] = self.data_per_point[key][idx] - if type(idx) is slice: - return Tractogram(pts, new_data_per_streamline, new_data_per_point) + if isinstance(idx, int) or isinstance(idx, np.integer): + return TractogramItem(pts, data_per_streamline, data_per_point) - return TractogramItem(pts, new_data_per_streamline, new_data_per_point) + return Tractogram(pts, data_per_streamline, data_per_point) def __len__(self): return len(self.streamlines) def copy(self): """ Returns a copy of this `Tractogram` object. 
""" - new_data_per_streamline = {} + data_per_streamline = {} for key in self.data_per_streamline: - new_data_per_streamline[key] = self.data_per_streamline[key].copy() + data_per_streamline[key] = self.data_per_streamline[key].copy() - new_data_per_point = {} + data_per_point = {} for key in self.data_per_point: - new_data_per_point[key] = self.data_per_point[key].copy() + data_per_point[key] = self.data_per_point[key].copy() tractogram = Tractogram(self.streamlines.copy(), - new_data_per_streamline, - new_data_per_point) + data_per_streamline, + data_per_point) return tractogram def apply_affine(self, affine): From b9b844dd5f0531a5e19576b67639881307234178 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 22 Nov 2015 01:52:21 -0500 Subject: [PATCH 35/54] Added support for voxel order other than RAS in TRK file format --- nibabel/streamlines/tests/data/complex.trk | Bin 1228 -> 1296 bytes nibabel/streamlines/tests/test_trk.py | 117 ++++++++--- nibabel/streamlines/tractogram_file.py | 3 +- nibabel/streamlines/trk.py | 234 +++++++++++++-------- 4 files changed, 238 insertions(+), 116 deletions(-) diff --git a/nibabel/streamlines/tests/data/complex.trk b/nibabel/streamlines/tests/data/complex.trk index e6155154c43fa34577517fefce8f1ac14b9a4ce7..8aa099becd3bf9436d6cd0cc4ded087c3e794f8b 100644 GIT binary patch literal 1296 zcmWFua&-1)U<5-3h6Z~CW`F}0hUEO5{GwvG0EoeymWaX!aTqZ~2AKdWLvCtfUOcLI zm?2`NMP-R4rA4V=Co_V@N`Riu%+G^*VvzX`6j$hg5;1hMM)`v^1cDrciEc&eX>8WanLs@W4h#$v zfOrECg8)SC?3pt_IRxL9i`}7GgwKH$B8X-lTpdsx-8_igteKKPxeE>q3?G1)!x1D3 u0S)#LISHVc1`u1I$Qc_O1NC_T#Ug+>14Rz#XP_JmG}zm6xdYwk=Kuh;XGW#~ delta 157 zcmbQhb%t|-iWoCPadKi#Vo@;z5@4Rp$h1X_iJ_n= MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE: + warnings.warn(("Can only store {0} named properties: '{1}' will be omitted.".format(MAX_NB_NAMED_SCALARS_PER_POINT, k)), HeaderWarning) + + if len(k) > 19: + warnings.warn(("Property name '{0}' has be truncated to {1}.".format(k, k[:19])), HeaderWarning) + + v = data_for_streamline[k] + property_name = k[:19].ljust(19, '\x00') + np.array(v.shape[0], dtype=np.int8).tostring() + self.header['property_name'][i] = property_name + + # Update the 'scalar_name' field using 'data_per_point' of the tractogram. + data_for_points = tractogram[0].data_for_points + data_for_points_keys = sorted(data_for_points.keys())[:MAX_NB_NAMED_SCALARS_PER_POINT] + self.header['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20') + for i, k in enumerate(data_for_points_keys): + if i >= MAX_NB_NAMED_SCALARS_PER_POINT: + warnings.warn(("Can only store {0} named scalars: '{1}' will be omitted.".format(MAX_NB_NAMED_SCALARS_PER_POINT, k)), HeaderWarning) + + if len(k) > 19: + warnings.warn(("Scalar name '{0}' has be truncated to {1}.".format(k, k[:19])), HeaderWarning) + + v = data_for_points[k] + scalar_name = k[:19].ljust(19, '\x00') + np.array(v.shape[1], dtype=np.int8).tostring() + self.header['scalar_name'][i] = scalar_name + + + # `Tractogram` streamlines are in RAS+ and mm space, we will compute + # the affine matrix that will bring them back to 'voxelmm' as required + # by the TRK format. + affine = np.eye(4) # TrackVis considers coordinate (0,0,0) to be the corner of the - # voxel whereas streamlines passed in parameters assume (0,0,0) - # to be the center of the voxel. Thus, streamlines are shifted of - # half a voxel. - affine[:-1, -1] += np.array(self.header[Field.VOXEL_SIZES])/2. + # voxel whereas `Tractogram` streamlines assume (0,0,0) is the + # center of the voxel. 
+        # center of the voxel. Thus, streamlines are shifted by half a voxel.
+        offset = np.eye(4)
+        offset[:-1, -1] += np.array(self.header[Field.VOXEL_SIZES])/2.
+        affine = np.dot(offset, affine)
+
+
+        # Apply the inverse of the affine found in the TRK header.
+        # rasmm -> voxel
+        affine = np.dot(np.linalg.inv(self.header[Field.VOXEL_TO_RASMM]), affine)
+
+        # If the voxel order implied by the affine does not match the voxel
+        # order in the TRK header, change the orientation.
+        # voxel (affine) -> voxel (header)
+        header_ornt = self.header[Field.VOXEL_ORDER]
+        affine_ornt = "".join(nib.orientations.aff2axcodes(self.header[Field.VOXEL_TO_RASMM]))
+        header_ornt = nib.orientations.axcodes2ornt(header_ornt)
+        affine_ornt = nib.orientations.axcodes2ornt(affine_ornt)
+        ornt = nib.orientations.ornt_transform(affine_ornt, header_ornt)
+        M = nib.orientations.inv_ornt_aff(ornt, self.header[Field.DIMENSIONS])
+        affine = np.dot(M, affine)

-        tractogram.apply_affine(affine)
+        # Finally, send the streamlines to mm space.
+        # voxel -> voxelmm
+        scale = np.eye(4)
+        scale[range(3), range(3)] *= self.header[Field.VOXEL_SIZES]
+        affine = np.dot(scale, affine)
+
+        # The TRK format uses float32 as the data type for points.
+        affine = affine.astype(np.float32)

         for t in tractogram:
             if any((len(d) != len(t.streamline) for d in t.data_for_points.values())):
                 raise DataError("Missing scalars for some points!")

-            points = np.asarray(t.streamline, dtype=f4_dtype)
-            keys = sorted(t.data_for_points.keys())[:MAX_NB_NAMED_SCALARS_PER_POINT]
-            scalars = np.asarray([t.data_for_points[k] for k in keys], dtype=f4_dtype).reshape((len(points), -1))
-            keys = sorted(t.data_for_streamline.keys())[:MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE]
-            properties = np.asarray([t.data_for_streamline[k] for k in keys], dtype=f4_dtype).flatten()
+            points = apply_affine(affine, np.asarray(t.streamline, dtype=f4_dtype))
+            scalars = [np.asarray(t.data_for_points[k], dtype=f4_dtype) for k in data_for_points_keys]
+            scalars = np.concatenate([np.ndarray((len(points), 0), dtype=f4_dtype)] + scalars, axis=1)
+            properties = [np.asarray(t.data_for_streamline[k], dtype=f4_dtype) for k in data_for_streamline_keys]
+            properties = np.concatenate([np.array([], dtype=f4_dtype)] + properties)

             data = struct.pack(i4_dtype.str[:-1], len(points))
-            data += np.concatenate((points, scalars), axis=1).tostring()
+            data += np.concatenate([points, scalars], axis=1).tostring()
             data += properties.tostring()
             self.file.write(data)
@@ -291,8 +383,7 @@ def write(self, tractogram):
             self.nb_scalars += scalars.size
             self.nb_properties += len(properties)

-        # Either correct or warn if header and data are incoherent.
-        #TODO: add a warn option as a function parameter
+        # Use those values to update the header.
         nb_scalars_per_point = self.nb_scalars / self.nb_points
         nb_properties_per_streamline = self.nb_properties / self.nb_streamlines
@@ -429,8 +520,12 @@ def __init__(self, tractogram, header=None, ref=np.eye(4)):
             Streamlines of the tractogram are assumed to be in *RAS+*
             and *mm* space where coordinate (0,0,0) refers to the center
             of the voxel.
        """
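The write path above chains four transforms (half-voxel shift, inverse of the
header affine, axcode reorientation, voxel-size scaling) to go from RAS+/mm
back to TRK's 'voxelmm' space. A condensed sketch of that composition using
nibabel's orientation helpers; the function name and the header-value
arguments are assumptions, not part of the patch:

    import numpy as np
    import nibabel as nib

    def rasmm_to_voxelmm(voxel_sizes, voxel_order, voxel_to_rasmm, dimensions):
        # rasmm -> voxel: shift by half a voxel, then undo the header affine.
        offset = np.eye(4)
        offset[:-1, -1] += np.asarray(voxel_sizes) / 2.
        affine = np.dot(np.linalg.inv(voxel_to_rasmm), offset)
        # voxel (affine) -> voxel (header): reorient if the axcodes differ.
        header_ornt = nib.orientations.axcodes2ornt(voxel_order)
        affine_ornt = nib.orientations.axcodes2ornt(
            nib.orientations.aff2axcodes(voxel_to_rasmm))
        ornt = nib.orientations.ornt_transform(affine_ornt, header_ornt)
        affine = np.dot(nib.orientations.inv_ornt_aff(ornt, dimensions), affine)
        # voxel -> voxelmm: scale by the voxel sizes; TRK points are float32.
        scale = np.diag(list(voxel_sizes) + [1.])
        return np.dot(scale, affine).astype(np.float32)

    # With an identity header affine and unit voxels, only the shift remains.
    aff = rasmm_to_voxelmm((1., 1., 1.), 'RAS', np.eye(4), (10, 10, 10))
    assert np.allclose(aff[:3, 3], 0.5)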
""" + if header is None: + header_rec = TrkWriter.create_empty_header() + header = dict(zip(header_rec.dtype.names, header_rec)) + super(TrkFile, self).__init__(tractogram, header) - self._affine = get_affine_from_reference(ref) + #self._affine = get_affine_from_reference(ref) @classmethod def get_magic_number(cls): @@ -498,13 +593,20 @@ def load(cls, fileobj, lazy_load=False): ''' trk_reader = TrkReader(fileobj) - # TRK's streamlines are in 'voxelmm' space, we send them to rasmm. - # First send them to voxel space. + # TRK's streamlines are in 'voxelmm' space, we will compute the + # affine matrix that will bring them back to RAS+ and mm space. affine = np.eye(4) - affine[range(3), range(3)] /= trk_reader.header[Field.VOXEL_SIZES] - # If voxel order implied from the affine does not match the voxel - # order save in the TRK header, change the orientation. + # The affine matrix found in the TRK header requires the points to be + # in the voxel space. + # voxelmm -> voxel + scale = np.eye(4) + scale[range(3), range(3)] /= trk_reader.header[Field.VOXEL_SIZES] + affine = np.dot(scale, affine) + + # If the voxel order implied by the affine does not match the voxel + # order in the TRK header, change the orientation. + # voxel (header) -> voxel (affine) header_ornt = trk_reader.header[Field.VOXEL_ORDER] affine_ornt = "".join(nib.orientations.aff2axcodes(trk_reader.header[Field.VOXEL_TO_RASMM])) header_ornt = nib.orientations.axcodes2ornt(header_ornt) @@ -513,14 +615,16 @@ def load(cls, fileobj, lazy_load=False): M = nib.orientations.inv_ornt_aff(ornt, trk_reader.header[Field.DIMENSIONS]) affine = np.dot(M, affine) - # Applied the affine going from voxel space to rasmm. + # Applied the affine found in the TRK header. + # voxel -> rasmm affine = np.dot(trk_reader.header[Field.VOXEL_TO_RASMM], affine) - # TrackVis considers coordinate (0,0,0) to be the corner of the - # voxel whereas streamlines returned assume (0,0,0) to be the - # center of the voxel. Thus, streamlines are shifted of half - #a voxel. - affine[:-1, -1] -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2. + # TrackVis considers coordinate (0,0,0) to be the corner of the voxel + # whereas streamlines returned assume (0,0,0) to be the center of the + # voxel. Thus, streamlines are shifted of half a voxel. + offset = np.eye(4) + offset[:-1, -1] -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2. + affine = np.dot(offset, affine) if lazy_load: # TODO when LazyTractogram has been refactored. @@ -590,7 +694,7 @@ def _read(): tractogram.data_per_streamline['properties'] = properties[:, cpt:] # Bring tractogram to RAS+ and mm space - tractogram.apply_affine(affine) + tractogram.apply_affine(affine.astype(np.float32)) ## Perform some integrity checks #if tractogram.header.voxel_sizes != trk_reader.header[Field.VOXEL_SIZES]: @@ -612,40 +716,6 @@ def save(self, fileobj): pointing to TRK file (and ready to read from the beginning of the TRK header data). ''' - # Compute how many properties per streamline the tractogram has. 
-        # Compute how many properties per streamline the tractogram has.
-        self.header.nb_properties_per_streamline = 0
-        self.header.extra['property_name'] = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, dtype='S20')
-        data_for_streamline = self.tractogram[0].data_for_streamline
-        for i, k in enumerate(sorted(data_for_streamline.keys())):
-            if i >= MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE:
-                warnings.warn(("Can only store {0} named properties: '{1}' will be omitted.".format(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, k)), HeaderWarning)
-
-            if len(k) > 19:
-                warnings.warn(("Property name '{0}' has been truncated to '{1}'.".format(k, k[:19])), HeaderWarning)
-
-            v = data_for_streamline[k]
-            self.header.nb_properties_per_streamline += v.shape[0]
-
-            property_name = k[:19].ljust(19, '\x00') + np.array(v.shape[0], dtype=np.int8).tostring()
-            self.header.extra['property_name'][i] = property_name
-
-        # Compute how many scalars per point the tractogram has.
-        self.header.nb_scalars_per_point = 0
-        self.header.extra['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20')
-        data_for_points = self.tractogram[0].data_for_points
-        for i, k in enumerate(sorted(data_for_points.keys())):
-            if i >= MAX_NB_NAMED_SCALARS_PER_POINT:
-                warnings.warn(("Can only store {0} named scalars: '{1}' will be omitted.".format(MAX_NB_NAMED_SCALARS_PER_POINT, k)), HeaderWarning)
-
-            if len(k) > 19:
-                warnings.warn(("Scalar name '{0}' has been truncated to '{1}'.".format(k, k[:19])), HeaderWarning)
-
-            v = data_for_points[k]
-            self.header.nb_scalars_per_point += v.shape[1]
-
-            scalar_name = k[:19].ljust(19, '\x00') + np.array(v.shape[1], dtype=np.int8).tostring()
-            self.header.extra['scalar_name'][i] = scalar_name
-
         trk_writer = TrkWriter(fileobj, self.header)
         trk_writer.write(self.tractogram)

From cf6de5a89950395ced03e8db013f8cc43baf458b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Sun, 22 Nov 2015 02:21:35 -0500
Subject: [PATCH 36/54] Fixed LazyTractogram

---
 nibabel/streamlines/tests/test_tractogram.py | 87 +++++++++-----------
 nibabel/streamlines/tractogram.py            |  6 +---
 2 files changed, 39 insertions(+), 54 deletions(-)

diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py
index f106e5e08b..fe7b414ef6 100644
--- a/nibabel/streamlines/tests/test_tractogram.py
+++ b/nibabel/streamlines/tests/test_tractogram.py
@@ -203,6 +203,10 @@ def setUp(self):

         self.nb_streamlines = len(self.streamlines)

+        self.colors_func = lambda: (x for x in self.colors)
+        self.mean_curvature_func = lambda: (x for x in self.mean_curvature)
+        self.mean_color_func = lambda: (x for x in self.mean_color)
+
     def test_lazy_tractogram_creation(self):
         # To create tractogram from arrays use `Tractogram`.
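These LazyTractogram tests lean on a convention the class imposes: it takes
callables that return fresh generators, never bare generators, because a
generator can only be consumed once. The difference in a few lines of plain
Python (the sample data is made up):

    streamlines_data = [[(0, 0, 0), (1, 1, 1)], [(2, 2, 2)]]

    gen = (s for s in streamlines_data)               # bare generator: one pass only
    assert list(gen) == streamlines_data
    assert list(gen) == []                            # already exhausted

    make_gen = lambda: (s for s in streamlines_data)  # callable: fresh pass per call
    assert list(make_gen()) == streamlines_data
    assert list(make_gen()) == streamlines_data       # works every time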
assert_raises(TypeError, LazyTractogram, self.streamlines) @@ -231,9 +235,9 @@ def test_lazy_tractogram_creation(self): # Create tractogram with streamlines and other data streamlines = lambda: (x for x in self.streamlines) - data_per_point = {"colors": lambda: (x for x in self.colors)} - data_per_streamline = {'mean_curv': lambda: (x for x in self.mean_curvature), - 'mean_color': lambda: (x for x in self.mean_color)} + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} tractogram = LazyTractogram(streamlines, data_per_streamline=data_per_streamline, @@ -274,75 +278,60 @@ def _data_gen(): def test_lazy_tractogram_indexing(self): streamlines = lambda: (x for x in self.streamlines) - scalars = lambda: (x for x in self.colors) - properties = lambda: (x for x in self.mean_curvature_torsion) + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} # By default, `LazyTractogram` object does not support indexing. - tractogram = LazyTractogram(streamlines, scalars, properties) + tractogram = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) assert_raises(AttributeError, tractogram.__getitem__, 0) - # Create a `LazyTractogram` object with indexing support. - def getitem_without_properties(idx): - if isinstance(idx, int) or isinstance(idx, np.integer): - return self.streamlines[idx], self.colors[idx] - - return list(zip(self.streamlines[idx], self.colors[idx])) - - tractogram = LazyTractogram(streamlines, scalars, properties, - getitem_without_properties) - streamlines, scalars = tractogram[0] - assert_array_equal(streamlines, self.streamlines[0]) - assert_array_equal(scalars, self.colors[0]) - - streamlines, scalars = zip(*tractogram[::-1]) - assert_arrays_equal(streamlines, self.streamlines[::-1]) - assert_arrays_equal(scalars, self.colors[::-1]) - - streamlines, scalars = zip(*tractogram[:-1]) - assert_arrays_equal(streamlines, self.streamlines[:-1]) - assert_arrays_equal(scalars, self.colors[:-1]) - def test_lazy_tractogram_len(self): streamlines = lambda: (x for x in self.streamlines) - scalars = lambda: (x for x in self.colors) - properties = lambda: (x for x in self.mean_curvature_torsion) + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} - with clear_and_catch_warnings(record=True, modules=[module_tractogram]) as w: + modules = [module_tractogram] # Modules for which to catch warnings. + with clear_and_catch_warnings(record=True, modules=modules) as w: warnings.simplefilter("always") # Always trigger warnings. # Calling `len` will create new generators each time. - tractogram = LazyTractogram(streamlines, scalars, properties) + tractogram = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + assert_true(tractogram._nb_streamlines is None) + # This should produce a warning message. assert_equal(len(tractogram), self.nb_streamlines) + assert_equal(tractogram._nb_streamlines, self.nb_streamlines) assert_equal(len(w), 1) - tractogram = LazyTractogram(streamlines, scalars, properties) - # This should still produce a warning message. 
+            tractogram = LazyTractogram(streamlines,
+                                        data_per_streamline=data_per_streamline,
+                                        data_per_point=data_per_point)
+
+            # New instances should still produce a warning message.
             assert_equal(len(tractogram), self.nb_streamlines)
             assert_equal(len(w), 2)
             assert_true(issubclass(w[-1].category, UsageWarning))

-            # This should *not* produce a warning.
+            # Calling 'len' again should *not* produce a warning.
             assert_equal(len(tractogram), self.nb_streamlines)
             assert_equal(len(w), 2)

-        with clear_and_catch_warnings(record=True, modules=[module_tractogram]) as w:
+        with clear_and_catch_warnings(record=True, modules=modules) as w:
             # Once we iterated through the tractogram, we know the length.
-            tractogram = LazyTractogram(streamlines, scalars, properties)
-            assert_true(tractogram.header.nb_streamlines is None)
-            for streamline in tractogram:
-                pass
-            assert_equal(tractogram.header.nb_streamlines,
-                         len(self.streamlines))
-            # This should *not* produce a warning.
-            assert_equal(len(tractogram), len(self.streamlines))
-            assert_equal(len(w), 0)
+            tractogram = LazyTractogram(streamlines,
+                                        data_per_streamline=data_per_streamline,
+                                        data_per_point=data_per_point)

-        with clear_and_catch_warnings(record=True, modules=[module_tractogram]) as w:
-            # It first checks if number of tractogram is in the header.
-            tractogram = LazyTractogram(streamlines, scalars, properties)
-            tractogram.header.nb_streamlines = 1234
+            assert_true(tractogram._nb_streamlines is None)
+            isiterable(tractogram)  # Force to iterate through all streamlines.
+            assert_equal(tractogram._nb_streamlines, len(self.streamlines))
             # This should *not* produce a warning.
-            assert_equal(len(tractogram), 1234)
+            assert_equal(len(tractogram), len(self.streamlines))
             assert_equal(len(w), 0)
diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
index c4ed299dc5..2f5a641b58 100644
--- a/nibabel/streamlines/tractogram.py
+++ b/nibabel/streamlines/tractogram.py
@@ -228,7 +228,6 @@ def __init__(self, streamlines=None, data_per_streamline=None, data_per_point)
         self._nb_streamlines = None
         self._data = None
-        self._getitem = None
         self._affine_to_apply = np.eye(4)

     @classmethod
@@ -343,10 +342,7 @@ def _gen_data():
         return _gen_data()

     def __getitem__(self, idx):
-        if self._getitem is None:
-            raise AttributeError('`LazyTractogram` does not support indexing.')
-
-        return self._getitem(idx)
+        raise AttributeError('`LazyTractogram` does not support indexing.')

     def __iter__(self):
         i = 0

From 29c05a36847cdfc8b5f2cdfd8730d5c68e06e205 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Sun, 22 Nov 2015 02:48:11 -0500
Subject: [PATCH 37/54] BF: Handle empty LazyTractogram

---
 nibabel/streamlines/tests/test_tractogram.py | 13 ++++-
 nibabel/streamlines/tractogram.py            | 30 +++++++-----
 nibabel/streamlines/trk.py                   | 50 +++++++++++++------
 3 files changed, 63 insertions(+), 30 deletions(-)

diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py
index fe7b414ef6..c07b548ae1 100644
--- a/nibabel/streamlines/tests/test_tractogram.py
+++ b/nibabel/streamlines/tests/test_tractogram.py
@@ -256,7 +256,18 @@ def test_lazy_tractogram_creation(self):
         assert_arrays_equal(tractogram.data_per_point['colors'],
                             self.colors)

-        # Create `LazyTractogram` from a coroutine yielding 3-tuples
+    def test_lazy_tractogram_create_from(self):
+        # Create `LazyTractogram` from a coroutine yielding nothing (i.e. empty).
+ _empty_data_gen = lambda: iter([]) + + tractogram = LazyTractogram.create_from(_empty_data_gen) + assert_true(isiterable(tractogram)) + assert_equal(len(tractogram), 0) + assert_arrays_equal(tractogram.streamlines, []) + assert_equal(tractogram.data_per_point, {}) + assert_equal(tractogram.data_per_streamline, {}) + + # Create `LazyTractogram` from a coroutine yielding TractogramItem def _data_gen(): for d in zip(self.streamlines, self.colors, self.mean_curvature, self.mean_color): diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 2f5a641b58..23e14d5685 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -249,21 +249,27 @@ def create_from(cls, data_func): lazy_tractogram = cls() lazy_tractogram._data = data_func - # Set data_per_streamline using data_func - def _gen(key): - return lambda: (t.data_for_streamline[key] for t in data_func()) + try: + first_item = next(data_func()) - data_per_streamline_keys = next(data_func()).data_for_streamline.keys() - for k in data_per_streamline_keys: - lazy_tractogram._data_per_streamline[k] = _gen(k) + # Set data_per_streamline using data_func + def _gen(key): + return lambda: (t.data_for_streamline[key] for t in data_func()) - # Set data_per_point using data_func - def _gen(key): - return lambda: (t.data_for_points[key] for t in data_func()) + data_per_streamline_keys = first_item.data_for_streamline.keys() + for k in data_per_streamline_keys: + lazy_tractogram._data_per_streamline[k] = _gen(k) - data_per_point_keys = next(data_func()).data_for_points.keys() - for k in data_per_point_keys: - lazy_tractogram._data_per_point[k] = _gen(k) + # Set data_per_point using data_func + def _gen(key): + return lambda: (t.data_for_points[key] for t in data_func()) + + data_per_point_keys = first_item.data_for_points.keys() + for k in data_per_point_keys: + lazy_tractogram._data_per_point[k] = _gen(k) + + except StopIteration: + pass return lazy_tractogram diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 27430df421..278b6e2ebe 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -627,24 +627,42 @@ def load(cls, fileobj, lazy_load=False): affine = np.dot(offset, affine) if lazy_load: - # TODO when LazyTractogram has been refactored. - def _apply_transform(trk_reader): - for pts, scals, props in trk_reader: - # TRK's streamlines are in 'voxelmm' space, we send them to voxel space. - pts = pts / trk_reader.header[Field.VOXEL_SIZES] - # TrackVis considers coordinate (0,0,0) to be the corner of the - # voxel whereas streamlines returned assume (0,0,0) to be the - # center of the voxel. Thus, streamlines are shifted of half - #a voxel. - pts -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2. 
- trk_reader - yield pts, scals, props + #pts, scals, props = next(iter(trk_reader)) + + data_per_point_slice = {} + if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0: + cpt = 0 + for scalar_name in trk_reader.header['scalar_name']: + if len(scalar_name) == 0: + continue + + nb_scalars = np.fromstring(scalar_name[-1], np.int8) + scalar_name = scalar_name.split('\x00')[0] + data_per_point_slice[scalar_name] = slice(cpt, cpt+nb_scalars) + cpt += nb_scalars + + if cpt < trk_reader.header[Field.NB_SCALARS_PER_POINT]: + data_per_point_slice['scalars'] = slice(cpt, trk_reader.header[Field.NB_SCALARS_PER_POINT]) + + data_per_streamline_slice = {} + if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: + cpt = 0 + for property_name in trk_reader.header['property_name']: + if len(property_name) == 0: + continue + + nb_properties = np.fromstring(property_name[-1], np.int8) + property_name = property_name.split('\x00')[0] + data_per_streamline_slice[property_name] = slice(cpt, cpt+nb_properties) + cpt += nb_properties + + if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: + data_per_streamline_slice['properties'] = slice(cpt, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]) def _read(): for pts, scals, props in trk_reader: - # TODO - data_for_streamline = {} - data_for_points = {} + data_for_streamline = {k: props[:, v] for k, v in data_per_streamline_slice.items()} + data_for_points = {k: scals[:, v] for k, v in data_per_point_slice.items()} yield TractogramItem(pts, data_for_streamline, data_for_points) tractogram = LazyTractogram.create_from(_read) @@ -671,7 +689,6 @@ def _read(): cpt += nb_scalars if cpt < trk_reader.header[Field.NB_SCALARS_PER_POINT]: - #tractogram.data_per_point['scalars'] = scalars clist = CompactList() clist._data = scalars._data[:, cpt:] clist._offsets = scalars._offsets @@ -690,7 +707,6 @@ def _read(): cpt += nb_properties if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: - #tractogram.data_per_streamline['properties'] = properties tractogram.data_per_streamline['properties'] = properties[:, cpt:] # Bring tractogram to RAS+ and mm space From 765053a60a6036695dd84fbef8778e778c45fd4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 22 Nov 2015 03:42:29 -0500 Subject: [PATCH 38/54] TRK supports LazyTractogram --- nibabel/streamlines/tests/test_tractogram.py | 22 +++++- nibabel/streamlines/tests/test_trk.py | 82 ++++++-------------- nibabel/streamlines/tractogram.py | 2 +- nibabel/streamlines/trk.py | 75 ++++++++++-------- 4 files changed, 89 insertions(+), 92 deletions(-) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index c07b548ae1..493be83301 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -5,7 +5,7 @@ from nibabel.testing import assert_arrays_equal, isiterable from nibabel.testing import suppress_warnings, clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true -from numpy.testing import assert_array_equal +from numpy.testing import assert_array_equal, assert_array_almost_equal from nibabel.externals.six.moves import zip from .. import tractogram as module_tractogram @@ -346,3 +346,23 @@ def test_lazy_tractogram_len(self): # This should *not* produce a warning. 
assert_equal(len(tractogram), len(self.streamlines)) assert_equal(len(w), 0) + + def test_lazy_tractogram_apply_affine(self): + streamlines = lambda: (x for x in self.streamlines) + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} + + affine = np.eye(4) + scaling = np.array((1, 2, 3), dtype=float) + affine[range(3), range(3)] = scaling + + tractogram = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + + tractogram.apply_affine(affine) + assert_true(isiterable(tractogram)) + assert_equal(len(tractogram), len(self.streamlines)) + for s1, s2 in zip(tractogram.streamlines, self.streamlines): + assert_array_almost_equal(s1, s2*scaling) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index d731873cea..4abd5b688c 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -48,6 +48,7 @@ def assert_tractogram_equal(t1, t2): def check_tractogram(tractogram, streamlines, data_per_streamline, data_per_point): # Check data + assert_true(isiterable(tractogram)) assert_equal(len(tractogram), len(streamlines)) assert_arrays_equal(tractogram.streamlines, streamlines) @@ -59,8 +60,6 @@ def check_tractogram(tractogram, streamlines, data_per_streamline, data_per_poin assert_arrays_equal(tractogram.data_per_point[key], data_per_point[key]) - assert_true(isiterable(tractogram)) - class TestTRK(unittest.TestCase): @@ -118,31 +117,22 @@ def setUp(self): self.affine = np.eye(4) def test_load_empty_file(self): - trk = TrkFile.load(self.empty_trk_filename, lazy_load=False) - check_tractogram(trk.tractogram, [], {}, {}) - - trk = TrkFile.load(self.empty_trk_filename, lazy_load=True) - # Suppress warning about loading a TRK file in lazy mode with count=0. 
- #with suppress_warnings(): - check_tractogram(trk.tractogram, [], [], []) + for lazy_load in [False, True]: + trk = TrkFile.load(self.empty_trk_filename, lazy_load=lazy_load) + check_tractogram(trk.tractogram, [], {}, {}) def test_load_simple_file(self): - trk = TrkFile.load(self.simple_trk_filename, lazy_load=False) - check_tractogram(trk.tractogram, self.streamlines, {}, {}) - - # trk = TrkFile.load(self.simple_trk_filename, lazy_load=True) - # check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) + for lazy_load in [False, True]: + trk = TrkFile.load(self.simple_trk_filename, lazy_load=lazy_load) + check_tractogram(trk.tractogram, self.streamlines, {}, {}) def test_load_complex_file(self): - trk = TrkFile.load(self.complex_trk_filename, lazy_load=False) - check_tractogram(trk.tractogram, - self.streamlines, - data_per_point=self.data_per_point, - data_per_streamline=self.data_per_streamline) - - # trk = TrkFile.load(self.complex_trk_filename, lazy_load=True) - # check_tractogram(trk.tractogram, self.nb_streamlines, - # self.streamlines, self.colors, self.mean_curvature_torsion) + for lazy_load in [False, True]: + trk = TrkFile.load(self.complex_trk_filename, lazy_load=lazy_load) + check_tractogram(trk.tractogram, + self.streamlines, + data_per_point=self.data_per_point, + data_per_streamline=self.data_per_streamline) def test_load_file_with_wrong_information(self): trk_file = open(self.simple_trk_filename, 'rb').read() @@ -153,12 +143,6 @@ def test_load_file_with_wrong_information(self): trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) check_tractogram(trk.tractogram, self.streamlines, {}, {}) - # trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) - # with clear_and_catch_warnings(record=True, modules=[base_format]) as w: - # check_tractogram(trk.tractogram, self.streamlines, {}, {}) - # assert_equal(len(w), 1) - # assert_true(issubclass(w[0].category, UsageWarning)) - # Simulate a TRK file where `voxel_order` was not provided. 
voxel_order = np.zeros(1, dtype="|S3").tostring() new_trk_file = trk_file[:948] + voxel_order + trk_file[948+3:] @@ -286,13 +270,18 @@ def test_write_erroneous_file(self): trk = TrkFile(tractogram, ref=self.affine) assert_raises(IndexError, trk.save, BytesIO()) - def test_load_write_simple_file(self): - trk = TrkFile.load(self.simple_trk_filename, lazy_load=False) - trk_file = BytesIO() - trk.save(trk_file) + def test_load_write_file(self): + for filename in [self.empty_trk_filename, self.simple_trk_filename, self.complex_trk_filename]: + for lazy_load in [False, True]: + trk = TrkFile.load(filename, lazy_load=lazy_load) + trk_file = BytesIO() + trk.save(trk_file) + + loaded_trk = TrkFile.load(filename, lazy_load=False) + assert_tractogram_equal(loaded_trk.tractogram, trk.tractogram) - # trk = TrkFile.load(self.simple_trk_filename, lazy_load=True) - # check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) + trk_file.seek(0, os.SEEK_SET) + #assert_equal(open(filename, 'rb').read(), trk_file.read()) def test_load_write_LPS_file(self): trk = TrkFile.load(self.simple_LPS_trk_filename, lazy_load=False) @@ -312,26 +301,3 @@ def test_load_write_LPS_file(self): trk_file.seek(0, os.SEEK_SET) assert_equal(open(self.simple_LPS_trk_filename, 'rb').read(), trk_file.read()) - - #check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, {}, {}) - - - - # def test_write_file_lazy_tractogram(self): - # streamlines = lambda: (point for point in self.streamlines) - # scalars = lambda: (scalar for scalar in self.colors) - # properties = lambda: (prop for prop in self.mean_curvature_torsion) - - # tractogram = LazyTractogram(streamlines, scalars, properties) - # # No need to manually set `nb_streamlines` in the header since we count - # # them as writing. - # #tractogram.header.nb_streamlines = self.nb_streamlines - - # trk_file = BytesIO() - # trk = TrkFile(tractogram, ref=self.affine) - # trk.save(trk_file) - # trk_file.seek(0, os.SEEK_SET) - - # trk = TrkFile.load(trk_file, ref=None, lazy_load=False) - # check_tractogram(trk.tractogram, self.nb_streamlines, - # self.streamlines, self.colors, self.mean_curvature_torsion) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 23e14d5685..2ae433fd4b 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -287,7 +287,7 @@ def _apply_affine(): for s in streamlines_gen: yield apply_affine(self._affine_to_apply, s) - streamlines_gen = _apply_affine() + return _apply_affine() return streamlines_gen diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 278b6e2ebe..3570afa544 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -295,8 +295,20 @@ def write(self, tractogram): i4_dtype = np.dtype("i4") f4_dtype = np.dtype("f4") + try: + first_item = next(iter(tractogram)) + except StopIteration: + # Empty tractogram + self.header[Field.NB_STREAMLINES] = 0 + self.header[Field.NB_SCALARS_PER_POINT] = 0 + self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = 0 + # Overwrite header with updated one. + self.file.seek(self.beginning, os.SEEK_SET) + self.file.write(self.header.tostring()) + return + # Update the 'property_name' field using 'data_per_streamline' of the tractogram. 
- data_for_streamline = tractogram[0].data_for_streamline + data_for_streamline = first_item.data_for_streamline data_for_streamline_keys = sorted(data_for_streamline.keys())[:MAX_NB_NAMED_SCALARS_PER_POINT] self.header['property_name'] = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, dtype='S20') for i, k in enumerate(data_for_streamline_keys): @@ -311,7 +323,7 @@ def write(self, tractogram): self.header['property_name'][i] = property_name # Update the 'scalar_name' field using 'data_per_point' of the tractogram. - data_for_points = tractogram[0].data_for_points + data_for_points = first_item.data_for_points data_for_points_keys = sorted(data_for_points.keys())[:MAX_NB_NAMED_SCALARS_PER_POINT] self.header['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20') for i, k in enumerate(data_for_points_keys): @@ -636,7 +648,7 @@ def load(cls, fileobj, lazy_load=False): if len(scalar_name) == 0: continue - nb_scalars = np.fromstring(scalar_name[-1], np.int8) + nb_scalars = int(np.fromstring(scalar_name[-1], np.int8)) scalar_name = scalar_name.split('\x00')[0] data_per_point_slice[scalar_name] = slice(cpt, cpt+nb_scalars) cpt += nb_scalars @@ -651,7 +663,7 @@ def load(cls, fileobj, lazy_load=False): if len(property_name) == 0: continue - nb_properties = np.fromstring(property_name[-1], np.int8) + nb_properties = int(np.fromstring(property_name[-1], np.int8)) property_name = property_name.split('\x00')[0] data_per_streamline_slice[property_name] = slice(cpt, cpt+nb_properties) cpt += nb_properties @@ -661,8 +673,8 @@ def load(cls, fileobj, lazy_load=False): def _read(): for pts, scals, props in trk_reader: - data_for_streamline = {k: props[:, v] for k, v in data_per_streamline_slice.items()} data_for_points = {k: scals[:, v] for k, v in data_per_point_slice.items()} + data_for_streamline = {k: props[v] for k, v in data_per_streamline_slice.items()} yield TractogramItem(pts, data_for_streamline, data_for_points) tractogram = LazyTractogram.create_from(_read) @@ -735,8 +747,7 @@ def save(self, fileobj): trk_writer = TrkWriter(fileobj, self.header) trk_writer.write(self.tractogram) - @staticmethod - def pretty_print(fileobj): + def __str__(self): ''' Gets a formatted string of the header of a TRK file. Parameters @@ -751,32 +762,32 @@ def pretty_print(fileobj): info : string Header information relevant to the TRK format. 
-        trk_reader = TrkReader(fileobj)
-        hdr = trk_reader.header
+        #trk_reader = TrkReader(fileobj)
+        hdr = self.header

         info = ""
-        info += "MAGIC NUMBER: {0}".format(hdr[Field.MAGIC_NUMBER])
-        info += "v.{0}".format(hdr['version'])
-        info += "dim: {0}".format(hdr[Field.DIMENSIONS])
-        info += "voxel_sizes: {0}".format(hdr[Field.VOXEL_SIZES])
-        info += "orgin: {0}".format(hdr[Field.ORIGIN])
-        info += "nb_scalars: {0}".format(hdr[Field.NB_SCALARS_PER_POINT])
-        info += "scalar_name:\n {0}".format("\n".join(hdr['scalar_name']))
-        info += "nb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE])
-        info += "property_name:\n {0}".format("\n".join(hdr['property_name']))
-        info += "vox_to_world: {0}".format(hdr[Field.VOXEL_TO_RASMM])
-        info += "voxel_order: {0}".format(hdr[Field.VOXEL_ORDER])
-        info += "image_orientation_patient: {0}".format(hdr['image_orientation_patient'])
-        info += "pad1: {0}".format(hdr['pad1'])
-        info += "pad2: {0}".format(hdr['pad2'])
-        info += "invert_x: {0}".format(hdr['invert_x'])
-        info += "invert_y: {0}".format(hdr['invert_y'])
-        info += "invert_z: {0}".format(hdr['invert_z'])
-        info += "swap_xy: {0}".format(hdr['swap_xy'])
-        info += "swap_yz: {0}".format(hdr['swap_yz'])
-        info += "swap_zx: {0}".format(hdr['swap_zx'])
-        info += "n_count: {0}".format(hdr[Field.NB_STREAMLINES])
-        info += "hdr_size: {0}".format(hdr['hdr_size'])
-        info += "endianess: {0}".format(hdr[Field.ENDIAN])
+        info += "\nMAGIC NUMBER: {0}".format(hdr[Field.MAGIC_NUMBER])
+        info += "\nv.{0}".format(hdr['version'])
+        info += "\ndim: {0}".format(hdr[Field.DIMENSIONS])
+        info += "\nvoxel_sizes: {0}".format(hdr[Field.VOXEL_SIZES])
+        info += "\norigin: {0}".format(hdr[Field.ORIGIN])
+        info += "\nnb_scalars: {0}".format(hdr[Field.NB_SCALARS_PER_POINT])
+        info += "\nscalar_name:\n {0}".format("\n".join(hdr['scalar_name']))
+        info += "\nnb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE])
+        info += "\nproperty_name:\n {0}".format("\n".join(hdr['property_name']))
+        info += "\nvox_to_world: {0}".format(hdr[Field.VOXEL_TO_RASMM])
+        info += "\nvoxel_order: {0}".format(hdr[Field.VOXEL_ORDER])
+        info += "\nimage_orientation_patient: {0}".format(hdr['image_orientation_patient'])
+        info += "\npad1: {0}".format(hdr['pad1'])
+        info += "\npad2: {0}".format(hdr['pad2'])
+        info += "\ninvert_x: {0}".format(hdr['invert_x'])
+        info += "\ninvert_y: {0}".format(hdr['invert_y'])
+        info += "\ninvert_z: {0}".format(hdr['invert_z'])
+        info += "\nswap_xy: {0}".format(hdr['swap_xy'])
+        info += "\nswap_yz: {0}".format(hdr['swap_yz'])
+        info += "\nswap_zx: {0}".format(hdr['swap_zx'])
+        info += "\nn_count: {0}".format(hdr[Field.NB_STREAMLINES])
+        info += "\nhdr_size: {0}".format(hdr['hdr_size'])
+        #info += "endianess: {0}".format(hdr[Field.ENDIAN])

         return info

From a2957af17eafa5ec45ed35442ffd0a07eb8866a1 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Sun, 22 Nov 2015 21:11:20 -0500
Subject: [PATCH 39/54] Fixed some unit tests

---
 nibabel/streamlines/tests/data/empty.trk      | Bin 1000 -> 1000 bytes
 nibabel/streamlines/tests/test_streamlines.py | 213 ++++++++++--------
 nibabel/streamlines/tests/test_tractogram.py  |  32 +++
 nibabel/streamlines/tests/test_trk.py         | 114 ++++------
 4 files changed, 190 insertions(+), 169 deletions(-)

diff --git a/nibabel/streamlines/tests/data/empty.trk b/nibabel/streamlines/tests/data/empty.trk
index e78e28403b087f627c298afc88c200dc2d50bcff..fbe087180770b69b1e4ff2c5f3ab1949f4a46a89 100644
GIT binary patch
delta
20 ccmaFC{(_x5B*@X(n_&UN35Eq5x$iOq07@JO0{{R3 diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index e27e7cc42a..dfd5026102 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -12,6 +12,7 @@ from nibabel.testing import assert_arrays_equal, isiterable from nose.tools import assert_equal, assert_raises, assert_true, assert_false +from .test_tractogram import assert_tractogram_equal from ..tractogram import Tractogram, LazyTractogram from ..tractogram_file import TractogramFile from ..tractogram import UsageWarning @@ -20,22 +21,6 @@ DATA_PATH = pjoin(os.path.dirname(__file__), 'data') -def check_tractogram(tractogram, nb_streamlines, streamlines, data_per_streamline, data_per_point): - # Check data - assert_equal(len(tractogram), nb_streamlines) - assert_arrays_equal(tractogram.streamlines, streamlines) - - for key in data_per_streamline.keys(): - assert_arrays_equal(tractogram.data_per_streamline[key], - data_per_streamline[key]) - - for key in data_per_point.keys(): - assert_arrays_equal(tractogram.data_per_point[key], - data_per_point[key]) - - assert_true(isiterable(tractogram)) - - def test_is_supported(): # Emtpy file/string f = BytesIO() @@ -112,121 +97,153 @@ def test_detect_format(): class TestLoadSave(unittest.TestCase): def setUp(self): - self.empty_filenames = [pjoin(DATA_PATH, "empty" + ext) for ext in nib.streamlines.FORMATS.keys()] - self.simple_filenames = [pjoin(DATA_PATH, "simple" + ext) for ext in nib.streamlines.FORMATS.keys()] - self.complex_filenames = [pjoin(DATA_PATH, "complex" + ext) for ext in nib.streamlines.FORMATS.keys()] + self.empty_filenames = [pjoin(DATA_PATH, "empty" + ext) + for ext in nib.streamlines.FORMATS.keys()] + self.simple_filenames = [pjoin(DATA_PATH, "simple" + ext) + for ext in nib.streamlines.FORMATS.keys()] + self.complex_filenames = [pjoin(DATA_PATH, "complex" + ext) + for ext in nib.streamlines.FORMATS.keys()] self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), np.arange(2*3, dtype="f4").reshape((2, 3)), np.arange(5*3, dtype="f4").reshape((5, 3))] + self.fa = [np.array([[0.2]], dtype="f4"), + np.array([[0.3], + [0.4]], dtype="f4"), + np.array([[0.5], + [0.6], + [0.6], + [0.7], + [0.8]], dtype="f4")] + self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), np.array([(0, 1, 0)]*2, dtype="f4"), np.array([(0, 0, 1)]*5, dtype="f4")] - self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), - np.array([2.11, 2.22], dtype="f4"), - np.array([3.11, 3.22], dtype="f4")] + self.mean_curvature = [np.array([1.11], dtype="f4"), + np.array([2.11], dtype="f4"), + np.array([3.11], dtype="f4")] - self.data_per_point = {'scalars': self.colors} - self.data_per_streamline = {'properties': self.mean_curvature_torsion} + self.mean_torsion = [np.array([1.22], dtype="f4"), + np.array([2.22], dtype="f4"), + np.array([3.22], dtype="f4")] - self.nb_streamlines = len(self.streamlines) - self.nb_scalars_per_point = self.colors[0].shape[1] - self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) - self.to_world_space = np.eye(4) + self.mean_colors = [np.array([1, 0, 0], dtype="f4"), + np.array([0, 1, 0], dtype="f4"), + np.array([0, 0, 1], dtype="f4")] + + self.data_per_point = {'colors': self.colors, + 'fa': self.fa} + self.data_per_streamline = {'mean_curvature': self.mean_curvature, + 'mean_torsion': self.mean_torsion, + 'mean_colors': self.mean_colors} + + self.empty_tractogram = Tractogram() + self.simple_tractogram 
= Tractogram(self.streamlines)
+        self.complex_tractogram = Tractogram(self.streamlines,
+                                             self.data_per_streamline,
+                                             self.data_per_point)
+
+        #self.nb_streamlines = len(self.streamlines)
+        #self.nb_scalars_per_point = self.colors[0].shape[1]
+        #self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0])
+        #self.to_world_space = np.eye(4)

     def test_load_empty_file(self):
-        for empty_filename in self.empty_filenames:
-            tractogram_file = nib.streamlines.load(empty_filename,
-                                                   lazy_load=False)
-            assert_true(isinstance(tractogram_file, TractogramFile))
-            assert_true(type(tractogram_file.tractogram), Tractogram)
-            check_tractogram(tractogram_file.tractogram, 0, [], {}, {})
+        for lazy_load in [False, True]:
+            for empty_filename in self.empty_filenames:
+                tfile = nib.streamlines.load(empty_filename,
+                                             lazy_load=lazy_load)
+                assert_true(isinstance(tfile, TractogramFile))
+
+                if lazy_load:
+                    assert_true(isinstance(tfile.tractogram, LazyTractogram))
+                else:
+                    assert_true(isinstance(tfile.tractogram, Tractogram))
+
+                assert_tractogram_equal(tfile.tractogram,
+                                        self.empty_tractogram)

     def test_load_simple_file(self):
-        for simple_filename in self.simple_filenames:
-            tractogram_file = nib.streamlines.load(simple_filename,
-                                                   lazy_load=False)
-            assert_true(isinstance(tractogram_file, TractogramFile))
-            assert_true(type(tractogram_file.tractogram), Tractogram)
-            check_tractogram(tractogram_file.tractogram, self.nb_streamlines,
-                             self.streamlines, {}, {})
-
-            # # Test lazy_load
-            # tractogram_file = nib.streamlines.load(simple_filename,
-            #                                        lazy_load=True)
-
-            # assert_true(isinstance(tractogram_file, TractogramFile))
-            # assert_true(type(tractogram_file.tractogram), LazyTractogram)
-            # check_tractogram(tractogram_file.tractogram, self.nb_streamlines,
-            #                  self.streamlines, {}, {})
+        for lazy_load in [False, True]:
+            for simple_filename in self.simple_filenames:
+                tfile = nib.streamlines.load(simple_filename,
+                                             lazy_load=lazy_load)
+                assert_true(isinstance(tfile, TractogramFile))
+
+                if lazy_load:
+                    assert_true(isinstance(tfile.tractogram, LazyTractogram))
+                else:
+                    assert_true(isinstance(tfile.tractogram, Tractogram))
+
+                assert_tractogram_equal(tfile.tractogram,
+                                        self.simple_tractogram)

     def test_load_complex_file(self):
-        for complex_filename in self.complex_filenames:
-            file_format = nib.streamlines.detect_format(complex_filename)
-
-            data_per_point = {}
-            if file_format.support_data_per_point():
-                data_per_point = {'scalars': self.colors}
-
-            data_per_streamline = []
-            if file_format.support_data_per_streamline():
-                data_per_streamline = {'properties': self.mean_curvature_torsion}
-
-            tractogram_file = nib.streamlines.load(complex_filename,
-                                                   lazy_load=False)
-            assert_true(isinstance(tractogram_file, TractogramFile))
-            assert_true(type(tractogram_file.tractogram), Tractogram)
-            check_tractogram(tractogram_file.tractogram, self.nb_streamlines,
-                             self.streamlines,
-                             data_per_streamline=data_per_streamline,
-                             data_per_point=data_per_point)
-
-            # # Test lazy_load
-            # tractogram_file = nib.streamlines.load(complex_filename,
-            #                                        lazy_load=True)
-            # assert_true(isinstance(tractogram_file, TractogramFile))
-            # assert_true(type(tractogram_file.tractogram), LazyTractogram)
-            # check_tractogram(tractogram_file.tractogram, self.nb_streamlines,
-            #                  self.streamlines,
-            #                  data_per_streamline=data_per_streamline,
-            #                  data_per_point=data_per_point)
+        for lazy_load in [False, True]:
+            for complex_filename in self.complex_filenames:
+                tfile = nib.streamlines.load(complex_filename,
+                                             lazy_load=lazy_load)
+                assert_true(isinstance(tfile, TractogramFile))
TractogramFile)) + + if lazy_load: + assert_true(type(tfile.tractogram), Tractogram) + else: + assert_true(type(tfile.tractogram), LazyTractogram) + + tractogram = Tractogram(self.streamlines) + + if tfile.support_data_per_point(): + tractogram.data_per_point = self.data_per_point + + if tfile.support_data_per_streamline(): + tractogram.data_per_streamline = self.data_per_streamline + + assert_tractogram_equal(tfile.tractogram, + tractogram) + + def test_save_empty_file(self): + tractogram = Tractogram() + for ext, cls in nib.streamlines.FORMATS.items(): + with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: + nib.streamlines.save_tractogram(tractogram, f.name) + tfile = nib.streamlines.load(f, lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) def test_save_simple_file(self): tractogram = Tractogram(self.streamlines) for ext, cls in nib.streamlines.FORMATS.items(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: nib.streamlines.save_tractogram(tractogram, f.name) - tractogram_file = nib.streamlines.load(f, lazy_load=False) - check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - self.streamlines, {}, {}) + tfile = nib.streamlines.load(f, lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) def test_save_complex_file(self): - tractogram = Tractogram(self.streamlines, - data_per_streamline=self.data_per_streamline, - data_per_point=self.data_per_point) + complex_tractogram = Tractogram(self.streamlines, + self.data_per_streamline, + self.data_per_point) + for ext, cls in nib.streamlines.FORMATS.items(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: with clear_and_catch_warnings(record=True, modules=[trk]) as w: - nib.streamlines.save_tractogram(tractogram, f.name) + nib.streamlines.save_tractogram(complex_tractogram, f.name) - # If streamlines format does not support saving data per point - # or data per streamline, a warning message should be issued. - if not (cls.support_data_per_point() and cls.support_data_per_streamline()): + # If streamlines format does not support saving data per + # point or data per streamline, a warning message should + # be issued. 
+ if not (cls.support_data_per_point() + and cls.support_data_per_streamline()): assert_equal(len(w), 1) assert_true(issubclass(w[0].category, UsageWarning)) - data_per_point = {} + tractogram = Tractogram(self.streamlines) + if cls.support_data_per_point(): - data_per_point = self.data_per_point + tractogram.data_per_point = self.data_per_point - data_per_streamline = [] if cls.support_data_per_streamline(): - data_per_streamline = self.data_per_streamline + tractogram.data_per_streamline = self.data_per_streamline - tractogram_file = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) - check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - self.streamlines, - data_per_streamline=data_per_streamline, - data_per_point=data_per_point) + tfile = nib.streamlines.load(f, lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 493be83301..ac458bae32 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -13,6 +13,38 @@ from ..tractogram import TractogramItem, Tractogram, LazyTractogram +# def check_tractogram(tractogram, streamlines, +# data_per_streamline, data_per_point): +# # Check data +# assert_true(isiterable(tractogram)) +# assert_equal(len(tractogram), len(streamlines)) +# assert_arrays_equal(tractogram.streamlines, streamlines) + +# for key in data_per_streamline.keys(): +# assert_arrays_equal(tractogram.data_per_streamline[key], +# data_per_streamline[key]) + +# for key in data_per_point.keys(): +# assert_arrays_equal(tractogram.data_per_point[key], +# data_per_point[key]) + + +def assert_tractogram_equal(t1, t2): + assert_true(isiterable(t1)) + assert_equal(len(t1), len(t2)) + assert_arrays_equal(t1.streamlines, t2.streamlines) + + assert_equal(len(t1.data_per_streamline), len(t2.data_per_streamline)) + for key in t1.data_per_streamline.keys(): + assert_arrays_equal(t1.data_per_streamline[key], + t2.data_per_streamline[key]) + + assert_equal(len(t1.data_per_point), len(t2.data_per_point)) + for key in t1.data_per_point.keys(): + assert_arrays_equal(t1.data_per_point[key], + t2.data_per_point[key]) + + class TestTractogramItem(unittest.TestCase): def test_creating_tractogram_item(self): diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 4abd5b688c..cbd0ffc785 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -8,6 +8,7 @@ from nibabel.testing import assert_arrays_equal, isiterable from nose.tools import assert_equal, assert_raises, assert_true +from .test_tractogram import assert_tractogram_equal from .. 
import base_format from ..tractogram import Tractogram, LazyTractogram from ..base_format import DataError, HeaderError, HeaderWarning#, UsageWarning @@ -31,36 +32,6 @@ def assert_header_equal(h1, h2): assert_equal(header1, header2) -def assert_tractogram_equal(t1, t2): - assert_equal(len(t1), len(t2)) - assert_arrays_equal(t1.streamlines, t2.streamlines) - - assert_equal(len(t1.data_per_streamline), len(t2.data_per_streamline)) - for key in t1.data_per_streamline.keys(): - assert_arrays_equal(t1.data_per_streamline[key], - t2.data_per_streamline[key]) - - assert_equal(len(t1.data_per_point), len(t2.data_per_point)) - for key in t1.data_per_point.keys(): - assert_arrays_equal(t1.data_per_point[key], - t2.data_per_point[key]) - - -def check_tractogram(tractogram, streamlines, data_per_streamline, data_per_point): - # Check data - assert_true(isiterable(tractogram)) - assert_equal(len(tractogram), len(streamlines)) - assert_arrays_equal(tractogram.streamlines, streamlines) - - for key in data_per_streamline.keys(): - assert_arrays_equal(tractogram.data_per_streamline[key], - data_per_streamline[key]) - - for key in data_per_point.keys(): - assert_arrays_equal(tractogram.data_per_point[key], - data_per_point[key]) - - class TestTRK(unittest.TestCase): def setUp(self): @@ -101,38 +72,32 @@ def setUp(self): np.array([0, 1, 0], dtype="f4"), np.array([0, 0, 1], dtype="f4")] - self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), - np.array([2.11, 2.22], dtype="f4"), - np.array([3.11, 3.22], dtype="f4")] - self.data_per_point = {'colors': self.colors, 'fa': self.fa} self.data_per_streamline = {'mean_curvature': self.mean_curvature, 'mean_torsion': self.mean_torsion, 'mean_colors': self.mean_colors} - self.nb_streamlines = len(self.streamlines) - self.nb_scalars_per_point = self.colors[0].shape[1] - self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) - self.affine = np.eye(4) + self.empty_tractogram = Tractogram() + self.simple_tractogram = Tractogram(self.streamlines) + self.complex_tractogram = Tractogram(self.streamlines, + self.data_per_streamline, + self.data_per_point) def test_load_empty_file(self): for lazy_load in [False, True]: trk = TrkFile.load(self.empty_trk_filename, lazy_load=lazy_load) - check_tractogram(trk.tractogram, [], {}, {}) + assert_tractogram_equal(trk.tractogram, self.empty_tractogram) def test_load_simple_file(self): for lazy_load in [False, True]: trk = TrkFile.load(self.simple_trk_filename, lazy_load=lazy_load) - check_tractogram(trk.tractogram, self.streamlines, {}, {}) + assert_tractogram_equal(trk.tractogram, self.simple_tractogram) def test_load_complex_file(self): for lazy_load in [False, True]: trk = TrkFile.load(self.complex_trk_filename, lazy_load=lazy_load) - check_tractogram(trk.tractogram, - self.streamlines, - data_per_point=self.data_per_point, - data_per_streamline=self.data_per_streamline) + assert_tractogram_equal(trk.tractogram, self.complex_tractogram) def test_load_file_with_wrong_information(self): trk_file = open(self.simple_trk_filename, 'rb').read() @@ -141,7 +106,7 @@ def test_load_file_with_wrong_information(self): count = np.array(0, dtype="int32").tostring() new_trk_file = trk_file[:1000-12] + count + trk_file[1000-8:] trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) - check_tractogram(trk.tractogram, self.streamlines, {}, {}) + assert_tractogram_equal(trk.tractogram, self.simple_tractogram) # Simulate a TRK file where `voxel_order` was not provided. 
voxel_order = np.zeros(1, dtype="|S3").tostring() @@ -162,23 +127,39 @@ def test_load_file_with_wrong_information(self): new_trk_file = trk_file[:996] + hdr_size + trk_file[996+4:] assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) + def test_write_empty_file(self): + tractogram = Tractogram() + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + loaded_trk = TrkFile.load(trk_file) + assert_tractogram_equal(loaded_trk.tractogram, tractogram) + + loaded_trk_orig = TrkFile.load(self.empty_trk_filename) + assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) + + trk_file.seek(0, os.SEEK_SET) + assert_equal(trk_file.read(), open(self.empty_trk_filename, 'rb').read()) + def test_write_simple_file(self): tractogram = Tractogram(self.streamlines) trk_file = BytesIO() - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) loaded_trk = TrkFile.load(trk_file) - check_tractogram(loaded_trk.tractogram, - self.streamlines, {}, {}) + assert_tractogram_equal(loaded_trk.tractogram, tractogram) loaded_trk_orig = TrkFile.load(self.simple_trk_filename) assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) - assert_equal(open(self.simple_trk_filename, 'rb').read(), trk_file.read()) + assert_equal(trk_file.read(), open(self.simple_trk_filename, 'rb').read()) def test_write_complex_file(self): # With scalars @@ -186,30 +167,24 @@ def test_write_complex_file(self): data_per_point=self.data_per_point) trk_file = BytesIO() - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) loaded_trk = TrkFile.load(trk_file, lazy_load=False) - check_tractogram(loaded_trk.tractogram, - self.streamlines, - data_per_streamline={}, - data_per_point=self.data_per_point) + assert_tractogram_equal(loaded_trk.tractogram, tractogram) # With properties tractogram = Tractogram(self.streamlines, data_per_streamline=self.data_per_streamline) - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) trk_file = BytesIO() trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) loaded_trk = TrkFile.load(trk_file, lazy_load=False) - check_tractogram(loaded_trk.tractogram, - self.streamlines, - data_per_streamline=self.data_per_streamline, - data_per_point={}) + assert_tractogram_equal(loaded_trk.tractogram, tractogram) # With scalars and properties tractogram = Tractogram(self.streamlines, @@ -217,21 +192,18 @@ def test_write_complex_file(self): data_per_streamline=self.data_per_streamline) trk_file = BytesIO() - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) loaded_trk = TrkFile.load(trk_file, lazy_load=False) - check_tractogram(loaded_trk.tractogram, - self.streamlines, - data_per_streamline=self.data_per_streamline, - data_per_point=self.data_per_point) + assert_tractogram_equal(loaded_trk.tractogram, tractogram) loaded_trk_orig = TrkFile.load(self.complex_trk_filename) assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) - assert_equal(open(self.complex_trk_filename, 'rb').read(), trk_file.read()) + assert_equal(trk_file.read(), open(self.complex_trk_filename, 'rb').read()) def test_write_erroneous_file(self): # No scalars for every points @@ -241,7 +213,7 @@ def test_write_erroneous_file(self): tractogram = Tractogram(self.streamlines, 
data_per_point={'scalars': scalars}) - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) assert_raises(DataError, trk.save, BytesIO()) # No scalars for every streamlines @@ -250,7 +222,7 @@ def test_write_erroneous_file(self): tractogram = Tractogram(self.streamlines, data_per_point={'scalars': scalars}) - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) assert_raises(IndexError, trk.save, BytesIO()) # Inconsistent number of properties @@ -259,7 +231,7 @@ def test_write_erroneous_file(self): np.array([3.11, 3.22], dtype="f4")] tractogram = Tractogram(self.streamlines, data_per_streamline={'properties': properties}) - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) assert_raises(DataError, trk.save, BytesIO()) # No properties for every streamlines @@ -267,7 +239,7 @@ def test_write_erroneous_file(self): np.array([2.11, 2.22], dtype="f4")] tractogram = Tractogram(self.streamlines, data_per_streamline={'properties': properties}) - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) assert_raises(IndexError, trk.save, BytesIO()) def test_load_write_file(self): @@ -281,7 +253,7 @@ def test_load_write_file(self): assert_tractogram_equal(loaded_trk.tractogram, trk.tractogram) trk_file.seek(0, os.SEEK_SET) - #assert_equal(open(filename, 'rb').read(), trk_file.read()) + #assert_equal(trk_file.read(), open(filename, 'rb').read()) def test_load_write_LPS_file(self): trk = TrkFile.load(self.simple_LPS_trk_filename, lazy_load=False) @@ -300,4 +272,4 @@ def test_load_write_LPS_file(self): assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) - assert_equal(open(self.simple_LPS_trk_filename, 'rb').read(), trk_file.read()) + assert_equal(trk_file.read(), open(self.simple_LPS_trk_filename, 'rb').read()) From 03a6229a74913d2323686ab274bf3472687fc00e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 24 Nov 2015 00:25:23 -0500 Subject: [PATCH 40/54] BF: limit property and scalar names to 18 characters --- nibabel/streamlines/tests/test_tractogram.py | 16 -- nibabel/streamlines/tests/test_trk.py | 134 +++++++++++--- nibabel/streamlines/trk.py | 181 ++++++------------- 3 files changed, 165 insertions(+), 166 deletions(-) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index ac458bae32..64b0e150ab 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -13,22 +13,6 @@ from ..tractogram import TractogramItem, Tractogram, LazyTractogram -# def check_tractogram(tractogram, streamlines, -# data_per_streamline, data_per_point): -# # Check data -# assert_true(isiterable(tractogram)) -# assert_equal(len(tractogram), len(streamlines)) -# assert_arrays_equal(tractogram.streamlines, streamlines) - -# for key in data_per_streamline.keys(): -# assert_arrays_equal(tractogram.data_per_streamline[key], -# data_per_streamline[key]) - -# for key in data_per_point.keys(): -# assert_arrays_equal(tractogram.data_per_point[key], -# data_per_point[key]) - - def assert_tractogram_equal(t1, t2): assert_true(isiterable(t1)) assert_equal(len(t1), len(t2)) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index cbd0ffc785..4b9252e00d 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -13,7 +13,7 @@ from ..tractogram import Tractogram, LazyTractogram from ..base_format 
import DataError, HeaderError, HeaderWarning#, UsageWarning -#from .. import trk +from .. import trk as trk_module from ..trk import TrkFile, header_2_dtype DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') @@ -111,7 +111,7 @@ def test_load_file_with_wrong_information(self): # Simulate a TRK file where `voxel_order` was not provided. voxel_order = np.zeros(1, dtype="|S3").tostring() new_trk_file = trk_file[:948] + voxel_order + trk_file[948+3:] - with clear_and_catch_warnings(record=True, modules=[trk]) as w: + with clear_and_catch_warnings(record=True, modules=[trk_module]) as w: TrkFile.load(BytesIO(new_trk_file)) assert_equal(len(w), 1) assert_true(issubclass(w[0].category, HeaderWarning)) @@ -135,11 +135,11 @@ def test_write_empty_file(self): trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_trk = TrkFile.load(trk_file) - assert_tractogram_equal(loaded_trk.tractogram, tractogram) + new_trk = TrkFile.load(trk_file) + assert_tractogram_equal(new_trk.tractogram, tractogram) - loaded_trk_orig = TrkFile.load(self.empty_trk_filename) - assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) + new_trk_orig = TrkFile.load(self.empty_trk_filename) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(trk_file.read(), open(self.empty_trk_filename, 'rb').read()) @@ -152,11 +152,11 @@ def test_write_simple_file(self): trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_trk = TrkFile.load(trk_file) - assert_tractogram_equal(loaded_trk.tractogram, tractogram) + new_trk = TrkFile.load(trk_file) + assert_tractogram_equal(new_trk.tractogram, tractogram) - loaded_trk_orig = TrkFile.load(self.simple_trk_filename) - assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) + new_trk_orig = TrkFile.load(self.simple_trk_filename) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(trk_file.read(), open(self.simple_trk_filename, 'rb').read()) @@ -171,8 +171,8 @@ def test_write_complex_file(self): trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_trk = TrkFile.load(trk_file, lazy_load=False) - assert_tractogram_equal(loaded_trk.tractogram, tractogram) + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) # With properties tractogram = Tractogram(self.streamlines, @@ -183,8 +183,8 @@ def test_write_complex_file(self): trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_trk = TrkFile.load(trk_file, lazy_load=False) - assert_tractogram_equal(loaded_trk.tractogram, tractogram) + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) # With scalars and properties tractogram = Tractogram(self.streamlines, @@ -196,11 +196,11 @@ def test_write_complex_file(self): trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_trk = TrkFile.load(trk_file, lazy_load=False) - assert_tractogram_equal(loaded_trk.tractogram, tractogram) + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) - loaded_trk_orig = TrkFile.load(self.complex_trk_filename) - assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) + new_trk_orig = TrkFile.load(self.complex_trk_filename) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(trk_file.read(), open(self.complex_trk_filename, 'rb').read()) @@ -249,8 +249,8 @@ 
def test_load_write_file(self): trk_file = BytesIO() trk.save(trk_file) - loaded_trk = TrkFile.load(filename, lazy_load=False) - assert_tractogram_equal(loaded_trk.tractogram, trk.tractogram) + new_trk = TrkFile.load(filename, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, trk.tractogram) trk_file.seek(0, os.SEEK_SET) #assert_equal(trk_file.read(), open(filename, 'rb').read()) @@ -263,13 +263,97 @@ def test_load_write_LPS_file(self): trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_trk = TrkFile.load(trk_file) + new_trk = TrkFile.load(trk_file) - assert_header_equal(loaded_trk.header, trk.header) - assert_tractogram_equal(loaded_trk.tractogram, trk.tractogram) + assert_header_equal(new_trk.header, trk.header) + assert_tractogram_equal(new_trk.tractogram, trk.tractogram) - loaded_trk_orig = TrkFile.load(self.simple_LPS_trk_filename) - assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) + new_trk_orig = TrkFile.load(self.simple_LPS_trk_filename) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(trk_file.read(), open(self.simple_LPS_trk_filename, 'rb').read()) + + def test_write_too_many_scalars_and_properties(self): + # TRK supports up to 10 data_per_point. + data_per_point = {} + for i in range(10): + data_per_point['#{0}'.format(i)] = self.fa + + tractogram = Tractogram(self.streamlines, + data_per_point=data_per_point) + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + # More than 10 data_per_point should raise an error. + data_per_point['#{0}'.format(i+1)] = self.fa + + tractogram = Tractogram(self.streamlines, + data_per_point=data_per_point) + + trk = TrkFile(tractogram) + assert_raises(ValueError, trk.save, BytesIO()) + + # TRK supports up to 10 data_per_streamline. + data_per_streamline = {} + for i in range(10): + data_per_streamline['#{0}'.format(i)] = self.mean_torsion + + tractogram = Tractogram(self.streamlines, + data_per_streamline=data_per_streamline) + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + # More than 10 data_per_streamline should raise an error. + data_per_streamline['#{0}'.format(i+1)] = self.mean_torsion + + tractogram = Tractogram(self.streamlines, + data_per_streamline=data_per_streamline) + + trk = TrkFile(tractogram) + assert_raises(ValueError, trk.save, BytesIO()) + + def test_write_scalars_and_properties_name_too_long(self): + # TRK supports data_per_point name up to 20 characters. + # However, we reserve the last two characters to store + # the number of values associated to each data_per_point. + # So in reality we allow name of 18 characters, otherwise + # the name is truncated and warning is issue. + for nb_chars in range(22): + data_per_point = {'A'*nb_chars: self.fa} + tractogram = Tractogram(self.streamlines, + data_per_point=data_per_point) + + trk = TrkFile(tractogram) + if nb_chars > 18: + assert_raises(ValueError, trk.save, BytesIO()) + else: + trk.save(BytesIO()) + + # TRK supports data_per_streamline name up to 20 characters. + # However, we reserve the last two characters to store + # the number of values associated to each data_per_streamline. 
+ # So in reality we allow name of 18 characters, otherwise + # the name is truncated and warning is issue. + for nb_chars in range(22): + data_per_streamline = {'A'*nb_chars: self.mean_torsion} + tractogram = Tractogram(self.streamlines, + data_per_streamline=data_per_streamline) + + trk = TrkFile(tractogram) + if nb_chars > 18: + assert_raises(ValueError, trk.save, BytesIO()) + else: + trk.save(BytesIO()) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 3570afa544..5caa098abf 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -21,7 +21,6 @@ from .tractogram import TractogramItem, Tractogram, LazyTractogram from .header import Field -from .utils import get_affine_from_reference MAX_NB_NAMED_SCALARS_PER_POINT = 10 MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE = 10 @@ -217,45 +216,6 @@ def create_empty_header(cls): return header - # def __init__(self, fileobj, header): - # self.header = self.create_empty_header() - - # # Override hdr's fields by those contained in `header`. - # for k, v in header.extra.items(): - # if k in header_2_dtype.fields.keys(): - # self.header[k] = v - - # # TODO: Fix that ugly patch. - # # Because the assignment operator on ndarray of string only copy the - # # first entry, we have to do it explicitly! - # if "property_name" in header.extra: - # self.header["property_name"][:] = header.extra["property_name"][:] - - # if "scalar_name" in header.extra: - # self.header["scalar_name"][:] = header.extra["scalar_name"][:] - - # self.header[Field.NB_STREAMLINES] = 0 - # if header.nb_streamlines is not None: - # self.header[Field.NB_STREAMLINES] = header.nb_streamlines - - # self.header[Field.NB_SCALARS_PER_POINT] = header.nb_scalars_per_point - # self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = header.nb_properties_per_streamline - # self.header[Field.VOXEL_SIZES] = header.voxel_sizes - # self.header[Field.VOXEL_TO_RASMM] = header.to_world_space - # self.header[Field.VOXEL_ORDER] = header.voxel_order - - # # Keep counts for correcting incoherent fields or warn. - # self.nb_streamlines = 0 - # self.nb_points = 0 - # self.nb_scalars = 0 - # self.nb_properties = 0 - - # # Write header - # self.file = Opener(fileobj, mode="wb") - # # Keep track of the beginning of the header. - # self.beginning = self.file.tell() - # self.file.write(self.header.tostring()) - def __init__(self, fileobj, header): self.header = self.create_empty_header() @@ -264,15 +224,6 @@ def __init__(self, fileobj, header): if k in header_2_dtype.fields.keys(): self.header[k] = v - # self.header[Field.NB_STREAMLINES] = 0 - # if header.nb_streamlines is not None: - # self.header[Field.NB_STREAMLINES] = header.nb_streamlines - - # self.header[Field.NB_SCALARS_PER_POINT] = header.nb_scalars_per_point - # self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = header.nb_properties_per_streamline - # self.header[Field.VOXEL_SIZES] = header.voxel_sizes - # self.header[Field.VOXEL_TO_RASMM] = header.to_world_space - # By default, the voxel order is LPS. # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates if self.header[Field.VOXEL_ORDER] == b"": @@ -309,35 +260,34 @@ def write(self, tractogram): # Update the 'property_name' field using 'data_per_streamline' of the tractogram. 
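The 20-byte name trick described in the comments above can be hard to follow inside the diff. Below is a minimal standalone sketch of that encoding, assuming the semantics this series settles on in PATCH 41 (count byte only used when more than one value is stored) and Python 3 bytes handling; encode_trk_name and decode_trk_name are hypothetical helpers for illustration, not part of nibabel's API.

import numpy as np

def encode_trk_name(name, nb_values):
    # Hypothetical helper: pack a scalar/property name into the 20-byte
    # TRK header field. When more than one value is stored, bytes 0-17
    # hold the NUL-padded name, byte 18 is a NUL separator and byte 19
    # is the value count as an int8.
    if nb_values > 1:
        if len(name) > 18:
            raise ValueError("max 18 char. when storing more than one value")
        return (name.ljust(18, '\x00').encode('latin1') + b'\x00'
                + np.array(nb_values, dtype=np.int8).tobytes())
    if len(name) > 20:
        raise ValueError("max 20 char.")
    return name.ljust(20, '\x00').encode('latin1')

def decode_trk_name(field):
    # Inverse operation, mirroring the check done on load: a NUL at
    # byte 18 followed by a non-NUL byte 19 means the count was encoded
    # in the last byte; otherwise a single value is assumed.
    if field[18:19] == b'\x00' and field[19:20] != b'\x00':
        return field[:18].split(b'\x00')[0].decode('latin1'), field[19]
    return field.split(b'\x00')[0].decode('latin1'), 1

assert decode_trk_name(encode_trk_name('fa', 1)) == ('fa', 1)
assert decode_trk_name(encode_trk_name('colors', 3)) == ('colors', 3)

This is why 20-character names remain legal for single-valued data: the count byte is only reserved when it is actually needed.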
        data_for_streamline = first_item.data_for_streamline
-        data_for_streamline_keys = sorted(data_for_streamline.keys())[:MAX_NB_NAMED_SCALARS_PER_POINT]
+        if len(data_for_streamline) > MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE:
+            raise ValueError("Can only store {0} named data_per_streamline (properties).".format(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE))
+
+        data_for_streamline_keys = sorted(data_for_streamline.keys())
         self.header['property_name'] = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, dtype='S20')
         for i, k in enumerate(data_for_streamline_keys):
-            if i >= MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE:
-                warnings.warn(("Can only store {0} named properties: '{1}' will be omitted.".format(MAX_NB_NAMED_SCALARS_PER_POINT, k)), HeaderWarning)
-
-            if len(k) > 19:
-                warnings.warn(("Property name '{0}' has be truncated to {1}.".format(k, k[:19])), HeaderWarning)
+            if len(k) > 18:
+                raise ValueError("Property name '{0}' too long (max 18 char.)".format(k))

             v = data_for_streamline[k]
-            property_name = k[:19].ljust(19, '\x00') + np.array(v.shape[0], dtype=np.int8).tostring()
+            property_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(v.shape[0], dtype=np.int8).tostring()
             self.header['property_name'][i] = property_name

         # Update the 'scalar_name' field using 'data_per_point' of the tractogram.
         data_for_points = first_item.data_for_points
-        data_for_points_keys = sorted(data_for_points.keys())[:MAX_NB_NAMED_SCALARS_PER_POINT]
+        if len(data_for_points) > MAX_NB_NAMED_SCALARS_PER_POINT:
+            raise ValueError("Can only store {0} named data_per_point (scalars).".format(MAX_NB_NAMED_SCALARS_PER_POINT))
+
+        data_for_points_keys = sorted(data_for_points.keys())
         self.header['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20')
         for i, k in enumerate(data_for_points_keys):
-            if i >= MAX_NB_NAMED_SCALARS_PER_POINT:
-                warnings.warn(("Can only store {0} named scalars: '{1}' will be omitted.".format(MAX_NB_NAMED_SCALARS_PER_POINT, k)), HeaderWarning)
-
-            if len(k) > 19:
-                warnings.warn(("Scalar name '{0}' has be truncated to {1}.".format(k, k[:19])), HeaderWarning)
+            if len(k) > 18:
+                raise ValueError("Scalar name '{0}' too long (max 18 char.)".format(k))

             v = data_for_points[k]
-            scalar_name = k[:19].ljust(19, '\x00') + np.array(v.shape[1], dtype=np.int8).tostring()
+            scalar_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(v.shape[1], dtype=np.int8).tostring()
             self.header['scalar_name'][i] = scalar_name

-
         # `Tractogram` streamlines are in RAS+ and mm space, we will compute
         # the affine matrix that will bring them back to 'voxelmm' as required
         # by the TRK format.
@@ -638,39 +588,48 @@ def load(cls, fileobj, lazy_load=False):
             offset[:-1, -1] -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2.
             affine = np.dot(offset, affine)

-        if lazy_load:
-            #pts, scals, props = next(iter(trk_reader))
-            data_per_point_slice = {}
-            if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0:
-                cpt = 0
-                for scalar_name in trk_reader.header['scalar_name']:
-                    if len(scalar_name) == 0:
-                        continue
+        # Find scalars and properties name
+        data_per_point_slice = {}
+        if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0:
+            cpt = 0
+            for scalar_name in trk_reader.header['scalar_name']:
+                if len(scalar_name) == 0:
+                    continue

+                # Check if we encoded the number of values we stored for this scalar name.
+ nb_scalars = 1 + if scalar_name[-2] == '\x00' and scalar_name[-1] != '\x00': nb_scalars = int(np.fromstring(scalar_name[-1], np.int8)) - scalar_name = scalar_name.split('\x00')[0] - data_per_point_slice[scalar_name] = slice(cpt, cpt+nb_scalars) - cpt += nb_scalars - if cpt < trk_reader.header[Field.NB_SCALARS_PER_POINT]: - data_per_point_slice['scalars'] = slice(cpt, trk_reader.header[Field.NB_SCALARS_PER_POINT]) + scalar_name = scalar_name.split('\x00')[0] + data_per_point_slice[scalar_name] = slice(cpt, cpt+nb_scalars) + cpt += nb_scalars - data_per_streamline_slice = {} - if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: - cpt = 0 - for property_name in trk_reader.header['property_name']: - if len(property_name) == 0: - continue + if cpt < trk_reader.header[Field.NB_SCALARS_PER_POINT]: + data_per_point_slice['scalars'] = slice(cpt, trk_reader.header[Field.NB_SCALARS_PER_POINT]) + data_per_streamline_slice = {} + if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: + cpt = 0 + for property_name in trk_reader.header['property_name']: + if len(property_name) == 0: + continue + + # Check if we encoded the number of values we stocked for this property name. + nb_properties = 1 + if property_name[-2] == '\x00' and property_name[-1] != '\x00': nb_properties = int(np.fromstring(property_name[-1], np.int8)) - property_name = property_name.split('\x00')[0] - data_per_streamline_slice[property_name] = slice(cpt, cpt+nb_properties) - cpt += nb_properties - if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: - data_per_streamline_slice['properties'] = slice(cpt, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]) + property_name = property_name.split('\x00')[0] + data_per_streamline_slice[property_name] = slice(cpt, cpt+nb_properties) + cpt += nb_properties + + if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: + data_per_streamline_slice['properties'] = slice(cpt, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]) + + if lazy_load: def _read(): for pts, scals, props in trk_reader: data_for_points = {k: scals[:, v] for k, v in data_per_point_slice.items()} @@ -683,43 +642,15 @@ def _read(): streamlines, scalars, properties = create_compactlist_from_generator(trk_reader) tractogram = Tractogram(streamlines) - if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0: - cpt = 0 - for scalar_name in trk_reader.header['scalar_name']: - if len(scalar_name) == 0: - continue - - nb_scalars = np.fromstring(scalar_name[-1], np.int8) - - clist = CompactList() - clist._data = scalars._data[:, cpt:cpt+nb_scalars] - clist._offsets = scalars._offsets - clist._lengths = scalars._lengths - - scalar_name = scalar_name.split('\x00')[0] - tractogram.data_per_point[scalar_name] = clist - cpt += nb_scalars - - if cpt < trk_reader.header[Field.NB_SCALARS_PER_POINT]: - clist = CompactList() - clist._data = scalars._data[:, cpt:] - clist._offsets = scalars._offsets - clist._lengths = scalars._lengths - tractogram.data_per_point['scalars'] = clist - - if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: - cpt = 0 - for property_name in trk_reader.header['property_name']: - if len(property_name) == 0: - continue - - nb_properties = np.fromstring(property_name[-1], np.int8) - property_name = property_name.split('\x00')[0] - tractogram.data_per_streamline[property_name] = properties[:, cpt:cpt+nb_properties] - cpt += nb_properties - - if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: - tractogram.data_per_streamline['properties'] = properties[:, cpt:] + 
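The name/slice bookkeeping built above drives both the lazy and the eager loading paths: each named scalar (or property) maps to a slice of columns in the stacked arrays. A compact illustration of that mapping, with made-up names and data (not taken from the test files):

import numpy as np

# Hypothetical layout: 'fa' stores one value per point and 'colors'
# stores three, so together they occupy 4 scalar columns.
names_and_counts = [('fa', 1), ('colors', 3)]

data_per_point_slice = {}
cpt = 0
for name, nb_values in names_and_counts:
    data_per_point_slice[name] = slice(cpt, cpt + nb_values)
    cpt += nb_values

# Two points worth of stacked scalars, one row per point.
scals = np.arange(8, dtype=np.float32).reshape(2, 4)
data_for_points = dict((k, scals[:, v]) for k, v in data_per_point_slice.items())

assert data_for_points['fa'].shape == (2, 1)
assert data_for_points['colors'].shape == (2, 3)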
+        for scalar_name, slice_ in data_per_point_slice.items():
+            clist = CompactList()
+            clist._data = scalars._data[:, slice_]
+            clist._offsets = scalars._offsets
+            clist._lengths = scalars._lengths
+            tractogram.data_per_point[scalar_name] = clist
+
+        for property_name, slice_ in data_per_streamline_slice.items():
+            tractogram.data_per_streamline[property_name] = properties[:, slice_]

         # Bring tractogram to RAS+ and mm space
         tractogram.apply_affine(affine.astype(np.float32))

From 438eebab9d24bfaae5af706af4e4486a5a7659b2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Tue, 24 Nov 2015 10:48:14 -0500
Subject: [PATCH 41/54] BF: Only store the nb of values in the property or
 scalar name if it is greater than one

---
 nibabel/streamlines/tests/data/complex.trk | Bin 1296 -> 1296 bytes
 nibabel/streamlines/tests/test_trk.py      |  27 ++++++++++++++++---
 nibabel/streamlines/trk.py                 |  30 +++++++++++++++------
 3 files changed, 46 insertions(+), 11 deletions(-)

diff --git a/nibabel/streamlines/tests/data/complex.trk b/nibabel/streamlines/tests/data/complex.trk
index 8aa099becd3bf9436d6cd0cc4ded087c3e794f8b..0a874ea6e74a7646c163b437f5423443a4816105 100644
GIT binary patch
delta 51
zcmbQhHGyk_?_>d{g%jUt@-pP6Cg#PL 20:
+                assert_raises(ValueError, trk.save, BytesIO())
+            else:
+                trk.save(BytesIO())
+
         # TRK supports data_per_streamline name up to 20 characters.
         # However, we reserve the last two characters to store
         # the number of values associated to each data_per_streamline.
         # So in reality we allow name of 18 characters, otherwise
         # the name is truncated and warning is issue.
         for nb_chars in range(22):
-            data_per_streamline = {'A'*nb_chars: self.mean_torsion}
+            data_per_streamline = {'A'*nb_chars: self.mean_colors}
             tractogram = Tractogram(self.streamlines,
                                     data_per_streamline=data_per_streamline)

@@ -357,3 +368,13 @@ def test_write_scalars_and_properties_name_too_long(self):
                 assert_raises(ValueError, trk.save, BytesIO())
             else:
                 trk.save(BytesIO())
+
+            data_per_streamline = {'A'*nb_chars: self.mean_torsion}
+            tractogram = Tractogram(self.streamlines,
+                                    data_per_streamline=data_per_streamline)
+
+            trk = TrkFile(tractogram)
+            if nb_chars > 20:
+                assert_raises(ValueError, trk.save, BytesIO())
+            else:
+                trk.save(BytesIO())
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index 5caa098abf..3922f41e99 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -266,11 +266,18 @@ def write(self, tractogram):
         data_for_streamline_keys = sorted(data_for_streamline.keys())
         self.header['property_name'] = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, dtype='S20')
         for i, k in enumerate(data_for_streamline_keys):
-            if len(k) > 18:
-                raise ValueError("Property name '{0}' too long (max 18 char.)".format(k))
+            nb_values = data_for_streamline[k].shape[0]
+
+            if len(k) > 20:
+                raise ValueError("Property name '{0}' is too long (max 20 char.)".format(k))
+            elif len(k) > 18 and nb_values > 1:
+                raise ValueError("Property name '{0}' is too long (max 18 char. when storing more than one value).".format(k))
+
+            property_name = k
+            if nb_values > 1:
+                # Use the last two bytes of the name to store the nb of values associated to this data_for_streamline.
+                property_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(nb_values, dtype=np.int8).tostring()

-            v = data_for_streamline[k]
-            property_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(v.shape[0], dtype=np.int8).tostring()
             self.header['property_name'][i] = property_name

         # Update the 'scalar_name' field using 'data_per_point' of the tractogram.
@@ -281,11 +288,18 @@ def write(self, tractogram):
         data_for_points_keys = sorted(data_for_points.keys())
         self.header['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20')
         for i, k in enumerate(data_for_points_keys):
-            if len(k) > 18:
-                raise ValueError("Scalar name '{0}' too long (max 18 char.)".format(k))
+            nb_values = data_for_points[k].shape[1]
+
+            if len(k) > 20:
+                raise ValueError("Scalar name '{0}' is too long (max 20 char.)".format(k))
+            elif len(k) > 18 and nb_values > 1:
+                raise ValueError("Scalar name '{0}' is too long (max 18 char. when storing more than one value).".format(k))
+
+            scalar_name = k
+            if nb_values > 1:
+                # Use the last two bytes of the name to store the nb of values associated to this data_for_points.
+                scalar_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(nb_values, dtype=np.int8).tostring()

-            v = data_for_points[k]
-            scalar_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(v.shape[1], dtype=np.int8).tostring()
             self.header['scalar_name'][i] = scalar_name

         # `Tractogram` streamlines are in RAS+ and mm space, we will compute

From 69e20f3c3ad41f9bdb971ee829e86b9b63458171 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Sun, 29 Nov 2015 01:02:28 -0500
Subject: [PATCH 42/54] Added a script to generate standard test object.

---
 .../streamlines/tests/data/gen_standard.py    |  79 ++++++++++++++++++
 .../streamlines/tests/data/standard.LPS.trk   | Bin 0 -> 5800 bytes
 nibabel/streamlines/tests/data/standard.trk   | Bin 0 -> 5800 bytes
 nibabel/streamlines/tests/test_streamlines.py |   5 --
 nibabel/streamlines/tests/test_trk.py         |  20 +++--
 nibabel/streamlines/trk.py                    |  29 ++++---
 6 files changed, 107 insertions(+), 26 deletions(-)
 create mode 100644 nibabel/streamlines/tests/data/gen_standard.py
 create mode 100644 nibabel/streamlines/tests/data/standard.LPS.trk
 create mode 100644 nibabel/streamlines/tests/data/standard.trk

diff --git a/nibabel/streamlines/tests/data/gen_standard.py b/nibabel/streamlines/tests/data/gen_standard.py
new file mode 100644
index 0000000000..63b4173602
--- /dev/null
+++ b/nibabel/streamlines/tests/data/gen_standard.py
@@ -0,0 +1,79 @@
+import numpy as np
+import nibabel as nib
+
+from nibabel.streamlines import FORMATS
+from nibabel.streamlines.header import Field
+
+
+def mark_the_spot(mask):
+    """ Marks every nonzero voxel using streamlines to form a 3D 'X' inside.
+
+    Generates streamlines forming a 3D 'X' inside every nonzero voxel.
+
+    Parameters
+    ----------
+    mask : ndarray
+        Mask containing the spots to be marked.
+
+    Returns
+    -------
+    list of ndarrays
+        All streamlines needed to mark every nonzero voxel in the `mask`.
+    """
+    def _gen_straight_streamline(start, end, steps=3):
+        coords = []
+        for s, e in zip(start, end):
+            coords.append(np.linspace(s, e, steps))
+
+        return np.array(coords).T
+
+    # Generate a 3D 'X' template fitting inside the voxel centered at (0,0,0).
+ X = [] + X.append(_gen_straight_streamline((-0.5, -0.5, -0.5), (0.5, 0.5, 0.5))) + X.append(_gen_straight_streamline((-0.5, 0.5, -0.5), (0.5, -0.5, 0.5))) + X.append(_gen_straight_streamline((-0.5, 0.5, 0.5), (0.5, -0.5, -0.5))) + X.append(_gen_straight_streamline((-0.5, -0.5, 0.5), (0.5, 0.5, -0.5))) + + # Get the coordinates of voxels 'on' in the mask. + coords = np.array(zip(*np.where(mask))) + + streamlines = [] + for c in coords: + for line in X: + streamlines.append((line + c) * voxel_size) + + return streamlines + + +rng = np.random.RandomState(42) + +width = 4 # Coronal +height = 5 # Sagittal +depth = 7 # Axial + +voxel_size = np.array((1., 3., 2.)) + +# Generate a random mask with voxel order RAS+. +mask = rng.rand(width, height, depth) > 0.8 +mask = (255*mask).astype(np.uint8) + +# Build tractogram +streamlines = mark_the_spot(mask) +tractogram = nib.streamlines.Tractogram(streamlines) + +# Build header +affine = np.eye(4) +affine[range(3), range(3)] = voxel_size +header = {Field.DIMENSIONS: (width, height, depth), + Field.VOXEL_SIZES: voxel_size, + Field.VOXEL_TO_RASMM: affine, + Field.VOXEL_ORDER: 'RAS'} + +# Save the standard mask. +nii = nib.Nifti1Image(mask, affine=affine) +nib.save(nii, "standard.nii.gz") + +# Save the standard tractogram in every available file format. +for ext, cls in FORMATS.items(): + tfile = cls(tractogram, header) + nib.streamlines.save(tfile, "standard" + ext) diff --git a/nibabel/streamlines/tests/data/standard.LPS.trk b/nibabel/streamlines/tests/data/standard.LPS.trk new file mode 100644 index 0000000000000000000000000000000000000000..ebda71bdb80c27dec2a8e61f22585ca9f56fff41 GIT binary patch literal 5800 zcmeH}Ka$lv5X7hWQ>1$YhlJb!y&{C*6%Y{m2pbam2zvyM;-hRx$Pwt)Xjc8k{Y`g|?C1Kvf4zMAY;WwXy)(1zqgiX1@Ol~l^a=1;<rB*FQnL!u9*L8UDlAn%VsY6Wmu^KQM2&_Y3o_O|G4Atk2w5=(*3- z^SavCvadP!o?hddbl3Aty}m9u&z-oDPxU%ylhb~cbDzW^3(t=n|IYK{eYs4o*A*On zGPj&?CVF01`##5gfOh0-qUU+1UU2lu9QUGM`f^=y ztj}EJQ$6>&^{qJfIp?n9{K&1+FW0IU9DOslO25i!A71C@$3B8{AJ0eT9$Fkr^WQIj zABH)z#P5h)Ut=qCy}nI7uRGyva$a}j+-GvC$KMCDFy9^K@!jAa*W6!lDaZQEZ6{pr zx8mGq=2PE4&~3O3*N#hlx#s?YOS#;K*VUZn^SUGFKDYU}7ru{}8!qnIc>mg1b8o@L ze2=&cSMM)zhBK|WUETw)J90ZTpW}L6XHff$Z*aV4ZjpS-X}?}K^99Fznd3VY{a!J* zk#9};Dvr9DTMIpMP4v93_EnB~!}kI6$k#AeU#=^T^_h!&s^>oQz4AKGr=0stj%yqJ z@;k)7jC&Z1^xHQIj`f*aPB;@iuk(D_H*;|=@tur(tLQOb#nC5otK`$?(L~SdX1?Iu z$Ip*B?xD5g&TRbut4GW?%$X(nAjfsZu|9J~HtD&~t#8G-&*apX-`&V}mwdU7e9Ez2 z^Ld@;%UpanqTk4;ocqjts>kmd*E~OFS;q6A9@n9#9Qn*G>;)Q`%Y9Uw`^j?W`c?WMw-$QtbL(4itZze)=S0swFy}dQkE5S!b3N|$HUE&fUZ1&k z!r9dGx+CX4=N$XwAh$i^8ytNyw>1ciK9n=j^SUFK&xPDBdXJoAbWYBh>x%2=kT?Sy z`9jZq(zz6z`BH-0AEDRpi6a~RUOC4gw+hE~#c?ll zD>cUR4Lx#A^t|rKxz9X5uB|!8??K}ngU?dmE3PY!^_knqBA@EH&#iC8xz9QGi`p3E zc6pz<-f)F)aP-X_*^1LXyza=k&&=1gfKl(EAp`Clt?{49^)>&HxL%(*)>fQNJ+C`* z?sLvX4&JrjzAwJP?FF};aN4ie*;5~(*Y79&?&vuz%nbUk; nchqBj=J;C*-z(>1=gT);^k*f&<3e_P@RYyjOYa_iifYg5T%4`d0Zp zuE68%^Dhnh71TS#_xEP_)2$%E^RplCd%X9p4DRJSkT}X@t}XP)ndmWHaqg3R3B&hW z@n@Cu?KoDwo}-kxtml0foLm?0n>qJM^W6sbfa92NyA7`QEjapQZoA-2^t^BF`x^JR z%vVH@^Qmue^vPV2`IOs4&-=!F!MV@4F4PL&e{me#6OPrF`HEwI=2+u%pn9yG=Iebk z=RW88RPT;H-d}Q6*2{dwv6i`tTFe)EK+ zZbOahP(5-cdfqp4?sJ}xZ0OyobIeyok9{kSwanFp9=RfV-naG*j&pF^3$6izGd^Qq zj<7?|j(o>{Lnvq55@(oY#qGtf3-$W8;0n$u=eg_gyZelYn)>d&Z|2BmF6L7`*BtYC z->m08gJT|g{O&&RJ{{*$LHf53~sf^|Iawy5%GIOrspVSj^`#0wIyfCcfq;OHAlaO zAlJ&^dY|Y8=le6)E;wCB^jENaP9xuOoH^Eld&Y4?ux7{oC`Zm(MZ|O8aLjX=Gv4zB zx669HZ{~LNy~aI2yLDzf_tava`Ub~yTUy5^^C_q6@V+r$aPBki2f6;NV$bVxm-%Y) zU2*KsT$TAk&wb9mD~|J(Uw*=;CjFKdf%+) zKFJq8T*u^V*IWS~^XYoCZ*Y0N voxel affine = 
np.dot(np.linalg.inv(self.header[Field.VOXEL_TO_RASMM]), affine) @@ -330,6 +322,13 @@ def write(self, tractogram): M = nib.orientations.inv_ornt_aff(ornt, self.header[Field.DIMENSIONS]) affine = np.dot(M, affine) + # TrackVis considers coordinate (0,0,0) to be the corner of the + # voxel whereas `Tractogram` streamlines assume (0,0,0) is the + # center of the voxel. Thus, streamlines are shifted of half a voxel. + offset = np.eye(4) + offset[:-1, -1] += 0.5 + affine = np.dot(offset, affine) + # Finally send the streamlines in mm space. # voxel -> voxelmm scale = np.eye(4) @@ -580,6 +579,13 @@ def load(cls, fileobj, lazy_load=False): scale[range(3), range(3)] /= trk_reader.header[Field.VOXEL_SIZES] affine = np.dot(scale, affine) + # TrackVis considers coordinate (0,0,0) to be the corner of the voxel + # whereas streamlines returned assume (0,0,0) to be the center of the + # voxel. Thus, streamlines are shifted of half a voxel. + offset = np.eye(4) + offset[:-1, -1] -= 0.5 + affine = np.dot(offset, affine) + # If the voxel order implied by the affine does not match the voxel # order in the TRK header, change the orientation. # voxel (header) -> voxel (affine) @@ -595,13 +601,6 @@ def load(cls, fileobj, lazy_load=False): # voxel -> rasmm affine = np.dot(trk_reader.header[Field.VOXEL_TO_RASMM], affine) - # TrackVis considers coordinate (0,0,0) to be the corner of the voxel - # whereas streamlines returned assume (0,0,0) to be the center of the - # voxel. Thus, streamlines are shifted of half a voxel. - offset = np.eye(4) - offset[:-1, -1] -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2. - affine = np.dot(offset, affine) - # Find scalars and properties name data_per_point_slice = {} From ed12031a9b7a83a925036f82e5dff78dd033ac66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 29 Nov 2015 01:04:25 -0500 Subject: [PATCH 43/54] Added the mask used by the standard test object --- nibabel/streamlines/tests/data/standard.nii.gz | Bin 0 -> 143 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 nibabel/streamlines/tests/data/standard.nii.gz diff --git a/nibabel/streamlines/tests/data/standard.nii.gz b/nibabel/streamlines/tests/data/standard.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..98bb31a77860a668b865d16f4605b8612eba512c GIT binary patch literal 143 zcmV;A0C4{wiwFoOlv-8-|8sO Date: Sun, 29 Nov 2015 02:05:34 -0500 Subject: [PATCH 44/54] Added support for Python 3 --- nibabel/streamlines/compact_list.py | 2 +- nibabel/streamlines/tests/test_compact_list.py | 12 ++++++------ nibabel/streamlines/trk.py | 12 +++++++----- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index fb98536bab..785e363490 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -110,7 +110,7 @@ def extend(self, elements): lengths = elements._lengths else: self._data = np.concatenate([self._data] + list(elements), axis=0) - lengths = map(len, elements) + lengths = list(map(len, elements)) idx = self._offsets[-1] + self._lengths[-1] if len(self) > 0 else 0 self._lengths.extend(lengths) diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py index 71dac112b7..34b1ceb65d 100644 --- a/nibabel/streamlines/tests/test_compact_list.py +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -17,7 +17,7 @@ class TestCompactList(unittest.TestCase): def setUp(self): rng = 
np.random.RandomState(42) self.data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - self.lengths = map(len, self.data) + self.lengths = list(map(len, self.data)) self.clist = CompactList(self.data) def test_creating_empty_compactlist(self): @@ -31,7 +31,7 @@ def test_creating_empty_compactlist(self): def test_creating_compactlist_from_list(self): rng = np.random.RandomState(42) data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = map(len, data) + lengths = list(map(len, data)) clist = CompactList(data) assert_equal(len(clist), len(data)) @@ -54,7 +54,7 @@ def test_creating_compactlist_from_list(self): def test_creating_compactlist_from_generator(self): rng = np.random.RandomState(42) data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = map(len, data) + lengths = list(map(len, data)) gen = (e for e in data) clist = CompactList(gen) @@ -78,7 +78,7 @@ def test_creating_compactlist_from_generator(self): def test_creating_compactlist_from_compact_list(self): rng = np.random.RandomState(42) data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = map(len, data) + lengths = list(map(len, data)) clist = CompactList(data) clist2 = CompactList(clist) @@ -152,7 +152,7 @@ def test_compactlist_extend(self): rng = np.random.RandomState(1234) shape = self.clist.shape new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(5)] - lengths = map(len, new_data) + lengths = list(map(len, new_data)) clist.extend(new_data) assert_equal(len(clist), len(self.clist)+len(new_data)) assert_array_equal(clist._offsets[-len(new_data):], @@ -187,7 +187,7 @@ def test_compactlist_getitem(self): assert_array_equal(self.clist[i], e) # Get multiple items (this will create a view). - clist_view = self.clist[range(len(self.clist))] + clist_view = self.clist[list(range(len(self.clist)))] assert_true(clist_view is not self.clist) assert_true(clist_view._data is self.clist._data) assert_true(clist_view._offsets is not self.clist._offsets) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index cdff6d6624..251d7681c8 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -13,6 +13,7 @@ from nibabel.affines import apply_affine from nibabel.openers import Opener +from nibabel.py3k import asbytes, asstr from nibabel.volumeutils import (native_code, swapped_code) from .compact_list import CompactList @@ -276,7 +277,7 @@ def write(self, tractogram): property_name = k if nb_values > 1: # Use the last to bytes of the name to store the nb of values associated to this data_for_streamline. - property_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(nb_values, dtype=np.int8).tostring() + property_name = asbytes(k[:18].ljust(18, '\x00')) + b'\x00' + np.array(nb_values, dtype=np.int8).tostring() self.header['property_name'][i] = property_name @@ -298,7 +299,7 @@ def write(self, tractogram): scalar_name = k if nb_values > 1: # Use the last to bytes of the name to store the nb of values associated to this data_for_streamline. - scalar_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(nb_values, dtype=np.int8).tostring() + scalar_name = asbytes(k[:18].ljust(18, '\x00')) + b'\x00' + np.array(nb_values, dtype=np.int8).tostring() self.header['scalar_name'][i] = scalar_name @@ -314,7 +315,7 @@ def write(self, tractogram): # If the voxel order implied by the affine does not match the voxel # order in the TRK header, change the orientation. 
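The orientation handling touched in these hunks is one link in a longer chain: on load, TRK coordinates go voxelmm -> voxel -> half-voxel shift -> reordered voxel axes -> RAS+ mm. A sketch of the whole chain as one hypothetical free function (the function and its argument names are illustrative only; the inputs play the role of TRK header fields, and the nibabel.orientations calls are the same ones used in the diffs above):

import numpy as np
import nibabel as nib

def trk_voxelmm_to_rasmm(voxel_sizes, dimensions, voxel_order, voxel_to_rasmm):
    affine = np.eye(4)

    # voxelmm -> voxel
    scale = np.eye(4)
    scale[range(3), range(3)] /= voxel_sizes
    affine = np.dot(scale, affine)

    # TrackVis puts (0,0,0) at the voxel corner; shift by half a voxel
    # so (0,0,0) becomes the voxel center.
    offset = np.eye(4)
    offset[:-1, -1] -= 0.5
    affine = np.dot(offset, affine)

    # voxel (header order) -> voxel (order implied by the affine)
    header_ornt = nib.orientations.axcodes2ornt(tuple(voxel_order))
    affine_ornt = nib.orientations.axcodes2ornt(
        nib.orientations.aff2axcodes(voxel_to_rasmm))
    ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt)
    M = nib.orientations.inv_ornt_aff(ornt, dimensions)
    affine = np.dot(M, affine)

    # voxel -> rasmm
    return np.dot(voxel_to_rasmm, affine)

# E.g. an LPS-ordered header with 2 mm isotropic voxels:
aff = trk_voxelmm_to_rasmm(np.array([2., 2., 2.]), (10, 10, 10), 'LPS',
                           np.diag([-2., -2., 2., 1.]))

The write path applies the same steps in reverse, which is why its half-voxel offset uses +0.5 where load uses -0.5.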
# voxel (affine) -> voxel (header) - header_ornt = self.header[Field.VOXEL_ORDER] + header_ornt = asstr(self.header[Field.VOXEL_ORDER]) affine_ornt = "".join(nib.orientations.aff2axcodes(self.header[Field.VOXEL_TO_RASMM])) header_ornt = nib.orientations.axcodes2ornt(header_ornt) affine_ornt = nib.orientations.axcodes2ornt(affine_ornt) @@ -589,7 +590,7 @@ def load(cls, fileobj, lazy_load=False): # If the voxel order implied by the affine does not match the voxel # order in the TRK header, change the orientation. # voxel (header) -> voxel (affine) - header_ornt = trk_reader.header[Field.VOXEL_ORDER] + header_ornt = asstr(trk_reader.header[Field.VOXEL_ORDER]) affine_ornt = "".join(nib.orientations.aff2axcodes(trk_reader.header[Field.VOXEL_TO_RASMM])) header_ornt = nib.orientations.axcodes2ornt(header_ornt) affine_ornt = nib.orientations.axcodes2ornt(affine_ornt) @@ -601,12 +602,12 @@ def load(cls, fileobj, lazy_load=False): # voxel -> rasmm affine = np.dot(trk_reader.header[Field.VOXEL_TO_RASMM], affine) - # Find scalars and properties name data_per_point_slice = {} if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0: cpt = 0 for scalar_name in trk_reader.header['scalar_name']: + scalar_name = asstr(scalar_name) if len(scalar_name) == 0: continue @@ -626,6 +627,7 @@ def load(cls, fileobj, lazy_load=False): if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: cpt = 0 for property_name in trk_reader.header['property_name']: + property_name = asstr(property_name) if len(property_name) == 0: continue From 39ad83fe2b2c193bc1258cd7433acfe552f5303e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 29 Nov 2015 02:25:50 -0500 Subject: [PATCH 45/54] Revert changes made to __init__.py --- nibabel/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nibabel/__init__.py b/nibabel/__init__.py index d40a388487..779f6e8587 100644 --- a/nibabel/__init__.py +++ b/nibabel/__init__.py @@ -23,9 +23,9 @@ img3 = nib.load('spm_file.img') data = img1.get_data() - affine = img1.get_affine() + affine = img1.affine - print img1 + print(img1) nib.save(img1, 'my_file_copy.nii.gz') @@ -63,7 +63,6 @@ apply_orientation, aff2axcodes) from .imageclasses import class_map, ext_map, all_image_classes from . import trackvis -from .streamlines import Tractogram from . import mriutils # be friendly on systems with ancient numpy -- no tests, but at least From 3f5e60841d501d6a668d681dd6df7ad63ec2e642 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 29 Nov 2015 02:57:15 -0500 Subject: [PATCH 46/54] Removed streamlines benchmark for now. --- nibabel/benchmarks/bench_streamlines.py | 136 ------------------------ 1 file changed, 136 deletions(-) delete mode 100644 nibabel/benchmarks/bench_streamlines.py diff --git a/nibabel/benchmarks/bench_streamlines.py b/nibabel/benchmarks/bench_streamlines.py deleted file mode 100644 index 3d0da178c1..0000000000 --- a/nibabel/benchmarks/bench_streamlines.py +++ /dev/null @@ -1,136 +0,0 @@ -""" Benchmarks for load and save of streamlines - -Run benchmarks with:: - - import nibabel as nib - nib.bench() - -If you have doctests enabled by default in nose (with a noserc file or -environment variable), and you have a numpy version <= 1.6.1, this will also run -the doctests, let's hope they pass. 
- -Run this benchmark with: - - nosetests -s --match '(?:^|[\\b_\\.//-])[Bb]ench' /path/to/bench_streamlines.py -""" -from __future__ import division, print_function - -import os -import numpy as np - -from nibabel.externals.six import BytesIO -from nibabel.externals.six.moves import zip - -from nibabel.testing import assert_arrays_equal - -from numpy.testing import assert_array_equal -from nibabel.streamlines.base_format import Streamlines -from nibabel.streamlines import TrkFile - -import nibabel as nib -import nibabel.trackvis as tv - -from numpy.testing import measure - - -def bench_load_trk(): - NB_STREAMLINES = 1000 - NB_POINTS = 1000 - points = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)] - repeat = 20 - - trk_file = BytesIO() - #trk = list(zip(points, [None]*NB_STREAMLINES, [None]*NB_STREAMLINES)) - #tv.write(trk_file, trk) - streamlines = Streamlines(points) - TrkFile.save(streamlines, trk_file) - - # from pycallgraph import PyCallGraph - # from pycallgraph.output import GraphvizOutput - - # with PyCallGraph(output=GraphvizOutput()): - # #nib.streamlines.load(trk_file, ref=None, lazy_load=False) - - mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, ref=None, lazy_load=False)', repeat) - print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) - - mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file, points_space="voxel")', repeat) - print("Old: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old)) - print("Speedup of %2f" % (mtime_old/mtime_new)) - - # Points and scalars - scalars = [np.random.rand(NB_POINTS, 10).astype('float32') for i in range(NB_STREAMLINES)] - - trk_file = BytesIO() - #trk = list(zip(points, scalars, [None]*NB_STREAMLINES)) - #tv.write(trk_file, trk) - streamlines = Streamlines(points, scalars) - TrkFile.save(streamlines, trk_file) - - mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, ref=None, lazy_load=False)', repeat) - print("New: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_new)) - - mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file, points_space="voxel")', repeat) - print("Old: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_old)) - print("Speedup of %2f" % (mtime_old/mtime_new)) - - -def bench_save_trk(): - NB_STREAMLINES = 100 - NB_POINTS = 1000 - points = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)] - repeat = 10 - - # Only points - streamlines = Streamlines(points) - trk_file_new = BytesIO() - - mtime_new = measure('trk_file_new.seek(0, os.SEEK_SET); TrkFile.save(streamlines, trk_file_new)', repeat) - print("\nNew: Saved %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) - - trk_file_old = BytesIO() - trk = list(zip(points, [None]*NB_STREAMLINES, [None]*NB_STREAMLINES)) - mtime_old = measure('trk_file_old.seek(0, os.SEEK_SET); tv.write(trk_file_old, trk)', repeat) - print("Old: Saved %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old)) - print("Speedup of %2f" % (mtime_old/mtime_new)) - - trk_file_new.seek(0, os.SEEK_SET) - trk_file_old.seek(0, os.SEEK_SET) - streams, hdr = tv.read(trk_file_old) - - for pts, A in zip(points, streams): - assert_array_equal(pts, A[0]) - - trk = nib.streamlines.load(trk_file_new, ref=None, lazy_load=False) - assert_arrays_equal(points, trk.points) - - # Points and scalars - scalars = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)] - 
streamlines = Streamlines(points, scalars=scalars) - trk_file_new = BytesIO() - - mtime_new = measure('trk_file_new.seek(0, os.SEEK_SET); TrkFile.save(streamlines, trk_file_new)', repeat) - print("New: Saved %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_new)) - - trk_file_old = BytesIO() - trk = list(zip(points, scalars, [None]*NB_STREAMLINES)) - mtime_old = measure('trk_file_old.seek(0, os.SEEK_SET); tv.write(trk_file_old, trk)', repeat) - print("Old: Saved %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_old)) - print("Speedup of %2f" % (mtime_old/mtime_new)) - - trk_file_new.seek(0, os.SEEK_SET) - trk_file_old.seek(0, os.SEEK_SET) - streams, hdr = tv.read(trk_file_old) - - for pts, scal, A in zip(points, scalars, streams): - assert_array_equal(pts, A[0]) - assert_array_equal(scal, A[1]) - - trk = nib.streamlines.load(trk_file_new, ref=None, lazy_load=False) - - assert_arrays_equal(points, trk.points) - assert_arrays_equal(scalars, trk.scalars) - - -if __name__ == '__main__': - bench_save_trk() From 4e3c5450c152e1d1992cb49af73a592f9f1d0a5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 29 Nov 2015 14:55:00 -0500 Subject: [PATCH 47/54] Python2.6 compatibility fix. Thanks @effigies --- nibabel/streamlines/trk.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 251d7681c8..ede5f81124 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -647,8 +647,8 @@ def load(cls, fileobj, lazy_load=False): if lazy_load: def _read(): for pts, scals, props in trk_reader: - data_for_points = {k: scals[:, v] for k, v in data_per_point_slice.items()} - data_for_streamline = {k: props[v] for k, v in data_per_streamline_slice.items()} + data_for_points = dict((k, scals[:, v]) for k, v in data_per_point_slice.items()) + data_for_streamline = dict((k, props[v]) for k, v in data_per_streamline_slice.items()) yield TractogramItem(pts, data_for_streamline, data_for_points) tractogram = LazyTractogram.create_from(_read) From 4ebbdc2372eab73b165ac1d11bce5286fa66b23c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 29 Nov 2015 16:07:35 -0500 Subject: [PATCH 48/54] Added more unit tests to increase coverage. --- nibabel/streamlines/base_format.py | 12 --- nibabel/streamlines/compact_list.py | 16 +-- .../streamlines/tests/test_compact_list.py | 26 +++++ nibabel/streamlines/tests/test_tractogram.py | 98 ++++++++++++++++++- nibabel/streamlines/tests/test_trk.py | 3 +- nibabel/streamlines/tractogram.py | 10 +- nibabel/streamlines/tractogram_file.py | 12 +++ nibabel/streamlines/trk.py | 2 +- 8 files changed, 153 insertions(+), 26 deletions(-) delete mode 100644 nibabel/streamlines/base_format.py diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py deleted file mode 100644 index e694e7e0c4..0000000000 --- a/nibabel/streamlines/base_format.py +++ /dev/null @@ -1,12 +0,0 @@ - - -class HeaderWarning(Warning): - pass - - -class HeaderError(Exception): - pass - - -class DataError(Exception): - pass diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index 785e363490..ffe2e9be9d 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -5,6 +5,9 @@ class CompactList(object): """ Class for compacting list of ndarrays with matching shape except for the first dimension. 
""" + + BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. + def __init__(self, iterable=None): """ Parameters @@ -31,20 +34,19 @@ def __init__(self, iterable=None): elif iterable is not None: # Initialize the `CompactList` object from iterable's item. - BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. - offset = 0 for i, e in enumerate(iterable): e = np.asarray(e) if i == 0: - self._data = np.empty((BUFFER_SIZE,) + e.shape[1:], - dtype=e.dtype) + new_shape = (CompactList.BUFFER_SIZE,) + e.shape[1:] + self._data = np.empty(new_shape, dtype=e.dtype) end = offset + len(e) if end >= len(self._data): - # Resize is needed (at least `len(e)` items will be added). - self._data.resize((len(self._data) + len(e)+BUFFER_SIZE,) - + self.shape) + # Resize needed, adding `len(e)` new items plus some buffer. + nb_points = len(self._data) + nb_points += len(e) + CompactList.BUFFER_SIZE + self._data.resize((nb_points,) + self.shape) self._offsets.append(offset) self._lengths.append(len(e)) diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py index 34b1ceb65d..7a1c25000f 100644 --- a/nibabel/streamlines/tests/test_compact_list.py +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -51,6 +51,20 @@ def test_creating_compactlist_from_list(self): assert_true(clist._data is None) assert_true(clist.shape is None) + # Force CompactList constructor to use buffering. + old_buffer_size = CompactList.BUFFER_SIZE + CompactList.BUFFER_SIZE = 1 + clist = CompactList(data) + assert_equal(len(clist), len(data)) + assert_equal(len(clist._offsets), len(data)) + assert_equal(len(clist._lengths), len(data)) + assert_equal(clist._data.shape[0], sum(lengths)) + assert_equal(clist._data.shape[1], 3) + assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist._lengths, lengths) + assert_equal(clist.shape, data[0].shape[1:]) + CompactList.BUFFER_SIZE = old_buffer_size + def test_creating_compactlist_from_generator(self): rng = np.random.RandomState(42) data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] @@ -95,6 +109,11 @@ def test_compactlist_iter(self): for e, d in zip(self.clist, self.data): assert_array_equal(e, d) + # Try iterate through a corrupted CompactList object. + clist = self.clist.copy() + clist._lengths = clist._lengths[::2] + assert_raises(ValueError, list, clist) + def test_compactlist_copy(self): clist = self.clist.copy() assert_array_equal(clist._data, self.clist._data) @@ -219,6 +238,13 @@ def test_compactlist_getitem(self): assert_array_equal(clist_view[1], self.clist[2]) assert_array_equal(clist_view[2], self.clist[4]) + # Test invalid indexing + assert_raises(TypeError, self.clist.__getitem__, 'abc') + + def test_compactlist_repr(self): + # Test that calling repr on a CompactList object is not falling. 
+        repr(self.clist)
+
 
 
 def test_save_and_load_compact_list():
diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py
index 64b0e150ab..23e4440515 100644
--- a/nibabel/streamlines/tests/test_tractogram.py
+++ b/nibabel/streamlines/tests/test_tractogram.py
@@ -127,7 +127,7 @@ def test_tractogram_creation(self):
         assert_raises(ValueError, Tractogram, self.streamlines,
                       data_per_point=data_per_point)
 
-    def test_tractogram_getter(self):
+    def test_tractogram_getitem(self):
         # Tractogram with only streamlines
         tractogram = Tractogram(streamlines=self.streamlines)
 
@@ -200,6 +200,43 @@ def test_tractogram_add_new_data(self):
         assert_arrays_equal(r_tractogram.data_per_point['colors'],
                             self.colors[::-1])
 
+    def test_tractogram_copy(self):
+        # Create a tractogram with streamlines and other data.
+        tractogram1 = Tractogram(
+            self.streamlines,
+            data_per_streamline={'mean_curvature': self.mean_curvature,
+                                 'mean_color': self.mean_color},
+            data_per_point={'colors': self.colors})
+
+        # Create a copy of the tractogram.
+        tractogram2 = tractogram1.copy()
+
+        # Check we copied the data and not simply created new references.
+        assert_true(tractogram1 is not tractogram2)
+        assert_true(tractogram1.streamlines is not tractogram2.streamlines)
+        assert_true(tractogram1.data_per_streamline
+                    is not tractogram2.data_per_streamline)
+        assert_true(tractogram1.data_per_streamline['mean_curvature']
+                    is not tractogram2.data_per_streamline['mean_curvature'])
+        assert_true(tractogram1.data_per_streamline['mean_color']
+                    is not tractogram2.data_per_streamline['mean_color'])
+        assert_true(tractogram1.data_per_point
+                    is not tractogram2.data_per_point)
+        assert_true(tractogram1.data_per_point['colors']
+                    is not tractogram2.data_per_point['colors'])
+
+        # Check the data are the equivalent.
+        assert_true(isiterable(tractogram2))
+        assert_equal(len(tractogram1), len(tractogram2))
+        assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines)
+        assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines)
+        assert_arrays_equal(tractogram1.data_per_streamline['mean_curvature'],
+                            tractogram2.data_per_streamline['mean_curvature'])
+        assert_arrays_equal(tractogram1.data_per_streamline['mean_color'],
+                            tractogram2.data_per_streamline['mean_color'])
+        assert_arrays_equal(tractogram1.data_per_point['colors'],
+                            tractogram2.data_per_point['colors'])
+
 
 class TestLazyTractogram(unittest.TestCase):
 
@@ -303,7 +340,10 @@ def _data_gen():
         assert_arrays_equal(tractogram.data_per_point['colors'],
                             self.colors)
 
-    def test_lazy_tractogram_indexing(self):
+        # Creating a LazyTractogram from a non-coroutine should raise an error.
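
Why a generator object must be rejected while the zero-argument callable is
accepted: a generator can be consumed only once, so every later pass over the
tractogram would silently come up empty. A minimal, self-contained
demonstration (hypothetical names, not part of the patch):

    import numpy as np

    streamlines = [np.ones((3, 3), dtype=np.float32),
                   np.zeros((2, 3), dtype=np.float32)]

    def make_items():
        # Returns a *fresh* generator at each call; this kind of
        # zero-argument callable is what the code base calls a coroutine.
        return (pts for pts in streamlines)

    gen = make_items()
    assert len(list(gen)) == 2           # A first pass consumes it.
    assert len(list(gen)) == 0           # Replaying the generator yields nothing.
    assert len(list(make_items())) == 2  # The callable can be re-invoked at will.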
+ assert_raises(TypeError, LazyTractogram.create_from, _data_gen()) + + def test_lazy_tractogram_getitem(self): streamlines = lambda: (x for x in self.streamlines) data_per_point = {"colors": self.colors_func} data_per_streamline = {'mean_curv': self.mean_curvature_func, @@ -382,3 +422,57 @@ def test_lazy_tractogram_apply_affine(self): assert_equal(len(tractogram), len(self.streamlines)) for s1, s2 in zip(tractogram.streamlines, self.streamlines): assert_array_almost_equal(s1, s2*scaling) + + def test_lazy_tractogram_copy(self): + # Create tractogram with streamlines and other data + streamlines = lambda: (x for x in self.streamlines) + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} + + tractogram1 = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + assert_true(isiterable(tractogram1)) # Implicitly set _nb_streamlines. + + # Create a copy of the tractogram. + tractogram2 = tractogram1.copy() + + # Check we copied the data and not simply created new references. + assert_true(tractogram1 is not tractogram2) + + # When copying LazyTractogram, coroutines generating streamlines should + # be the same. + assert_true(tractogram1._streamlines is tractogram2._streamlines) + + # Copying LazyTractogram, creates new internal LazyDict objects, + # but coroutines contained in it should be the same. + assert_true(tractogram1._data_per_streamline + is not tractogram2._data_per_streamline) + assert_true(tractogram1.data_per_streamline.store['mean_curv'] + is tractogram2.data_per_streamline.store['mean_curv']) + assert_true(tractogram1.data_per_streamline.store['mean_color'] + is tractogram2.data_per_streamline.store['mean_color']) + assert_true(tractogram1._data_per_point + is not tractogram2._data_per_point) + assert_true(tractogram1.data_per_point.store['colors'] + is tractogram2.data_per_point.store['colors']) + + # The affine should be a copy. + assert_true(tractogram1._affine_to_apply + is not tractogram2._affine_to_apply) + assert_array_equal(tractogram1._affine_to_apply, + tractogram2._affine_to_apply) + + # Check the data are the equivalent. + assert_equal(tractogram1._nb_streamlines, tractogram2._nb_streamlines) + assert_true(isiterable(tractogram2)) + assert_equal(len(tractogram1), len(tractogram2)) + assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) + assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) + assert_arrays_equal(tractogram1.data_per_streamline['mean_curv'], + tractogram2.data_per_streamline['mean_curv']) + assert_arrays_equal(tractogram1.data_per_streamline['mean_color'], + tractogram2.data_per_streamline['mean_color']) + assert_arrays_equal(tractogram1.data_per_point['colors'], + tractogram2.data_per_point['colors']) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 80853ea80a..ba3f3d3291 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -9,9 +9,8 @@ from nose.tools import assert_equal, assert_raises, assert_true from .test_tractogram import assert_tractogram_equal -from .. import base_format from ..tractogram import Tractogram, LazyTractogram -from ..base_format import DataError, HeaderError, HeaderWarning#, UsageWarning +from ..tractogram_file import DataError, HeaderError, HeaderWarning from .. 
import trk as trk_module from ..trk import TrkFile, header_2_dtype diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 2ae433fd4b..29fb975eee 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -202,7 +202,12 @@ class LazyDict(collections.MutableMapping): def __init__(self, *args, **kwargs): self.store = dict() - self.update(dict(*args, **kwargs)) # Use update to set keys. + + # Use update to set keys. + if len(args) == 1 and isinstance(args[0], LazyTractogram.LazyDict): + self.update(dict(args[0].store.items())) + else: + self.update(dict(*args, **kwargs)) def __getitem__(self, key): return self.store[key]() @@ -377,8 +382,9 @@ def copy(self): tractogram = LazyTractogram(self._streamlines, self._data_per_streamline, self._data_per_point) - tractogram.nb_streamlines = self.nb_streamlines + tractogram._nb_streamlines = self._nb_streamlines tractogram._data = self._data + tractogram._affine_to_apply = self._affine_to_apply.copy() return tractogram def apply_affine(self, affine): diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index feea0e9822..8bb1aa41ea 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -3,6 +3,18 @@ from .header import TractogramHeader +class HeaderWarning(Warning): + pass + + +class HeaderError(Exception): + pass + + +class DataError(Exception): + pass + + class abstractclassmethod(classmethod): __isabstractmethod__ = True diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index ede5f81124..9b043ff821 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -18,7 +18,7 @@ from .compact_list import CompactList from .tractogram_file import TractogramFile -from .base_format import DataError, HeaderError, HeaderWarning +from .tractogram_file import DataError, HeaderError, HeaderWarning from .tractogram import TractogramItem, Tractogram, LazyTractogram from .header import Field From 35e53302a7cf3cf5ffdb33df1785114fdaededb5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 1 Dec 2015 14:39:50 -0500 Subject: [PATCH 49/54] Changed function name 'isiterable' to 'check_iteration' --- nibabel/streamlines/tests/test_streamlines.py | 2 +- nibabel/streamlines/tests/test_tractogram.py | 28 +++++++++---------- nibabel/streamlines/tests/test_trk.py | 2 +- nibabel/testing/__init__.py | 4 ++- 4 files changed, 19 insertions(+), 17 deletions(-) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 645b3ac4cb..f76e8ab581 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -9,7 +9,7 @@ from nibabel.externals.six import BytesIO from nibabel.testing import clear_and_catch_warnings -from nibabel.testing import assert_arrays_equal, isiterable +from nibabel.testing import assert_arrays_equal, check_iteration from nose.tools import assert_equal, assert_raises, assert_true, assert_false from .test_tractogram import assert_tractogram_equal diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 23e4440515..953765bc9b 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -2,7 +2,7 @@ import numpy as np import warnings -from nibabel.testing import assert_arrays_equal, isiterable +from nibabel.testing import assert_arrays_equal, 
check_iteration from nibabel.testing import suppress_warnings, clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true from numpy.testing import assert_array_equal, assert_array_almost_equal @@ -14,7 +14,7 @@ def assert_tractogram_equal(t1, t2): - assert_true(isiterable(t1)) + assert_true(check_iteration(t1)) assert_equal(len(t1), len(t2)) assert_arrays_equal(t1.streamlines, t2.streamlines) @@ -81,7 +81,7 @@ def test_tractogram_creation(self): assert_arrays_equal(tractogram.streamlines, []) assert_equal(tractogram.data_per_streamline, {}) assert_equal(tractogram.data_per_point, {}) - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) # Create a tractogram with only streamlines tractogram = Tractogram(streamlines=self.streamlines) @@ -89,7 +89,7 @@ def test_tractogram_creation(self): assert_arrays_equal(tractogram.streamlines, self.streamlines) assert_equal(tractogram.data_per_streamline, {}) assert_equal(tractogram.data_per_point, {}) - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) # Create a tractogram with streamlines and other data. tractogram = Tractogram( @@ -107,7 +107,7 @@ def test_tractogram_creation(self): assert_arrays_equal(tractogram.data_per_point['colors'], self.colors) - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) # Inconsistent number of scalars between streamlines wrong_data = [[(1, 0, 0)]*1, @@ -226,7 +226,7 @@ def test_tractogram_copy(self): is not tractogram2.data_per_point['colors']) # Check the data are the equivalent. - assert_true(isiterable(tractogram2)) + assert_true(check_iteration(tractogram2)) assert_equal(len(tractogram1), len(tractogram2)) assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) @@ -280,7 +280,7 @@ def test_lazy_tractogram_creation(self): # Empty `LazyTractogram` tractogram = LazyTractogram() - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) assert_equal(len(tractogram), 0) assert_arrays_equal(tractogram.streamlines, []) assert_equal(tractogram.data_per_point, {}) @@ -296,7 +296,7 @@ def test_lazy_tractogram_creation(self): data_per_streamline=data_per_streamline, data_per_point=data_per_point) - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) assert_equal(len(tractogram), self.nb_streamlines) # Coroutines get re-called and creates new iterators. @@ -314,7 +314,7 @@ def test_lazy_tractogram_create_from(self): _empty_data_gen = lambda: iter([]) tractogram = LazyTractogram.create_from(_empty_data_gen) - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) assert_equal(len(tractogram), 0) assert_arrays_equal(tractogram.streamlines, []) assert_equal(tractogram.data_per_point, {}) @@ -330,7 +330,7 @@ def _data_gen(): yield TractogramItem(d[0], data_for_streamline, data_for_points) tractogram = LazyTractogram.create_from(_data_gen) - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) assert_equal(len(tractogram), self.nb_streamlines) assert_arrays_equal(tractogram.streamlines, self.streamlines) assert_arrays_equal(tractogram.data_per_streamline['mean_curv'], @@ -397,7 +397,7 @@ def test_lazy_tractogram_len(self): data_per_point=data_per_point) assert_true(tractogram._nb_streamlines is None) - isiterable(tractogram) # Force to iterate through all streamlines. + check_iteration(tractogram) # Force iteration through tractogram. 
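
The line above does real work beyond sanity checking: a LazyTractogram cannot
know its length until its generator has run once, so forcing a full pass is
what populates the cached `_nb_streamlines`. A toy illustration of that
pattern (hypothetical class, not the actual implementation):

    class LazyLength(object):
        def __init__(self, make_gen):
            self._make_gen = make_gen
            self._n = None  # Unknown until one complete pass happens.

        def __iter__(self):
            n = 0
            for item in self._make_gen():
                n += 1
                yield item
            self._n = n  # Cache the count once the pass has completed.

        def __len__(self):
            if self._n is None:
                for _ in self:  # Force the full pass, as the test does.
                    pass
            return self._n

For instance, len(LazyLength(lambda: iter(range(5)))) iterates once, caches 5,
and answers instantly on later calls.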
assert_equal(tractogram._nb_streamlines, len(self.streamlines)) # This should *not* produce a warning. assert_equal(len(tractogram), len(self.streamlines)) @@ -418,7 +418,7 @@ def test_lazy_tractogram_apply_affine(self): data_per_point=data_per_point) tractogram.apply_affine(affine) - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) assert_equal(len(tractogram), len(self.streamlines)) for s1, s2 in zip(tractogram.streamlines, self.streamlines): assert_array_almost_equal(s1, s2*scaling) @@ -433,7 +433,7 @@ def test_lazy_tractogram_copy(self): tractogram1 = LazyTractogram(streamlines, data_per_streamline=data_per_streamline, data_per_point=data_per_point) - assert_true(isiterable(tractogram1)) # Implicitly set _nb_streamlines. + assert_true(check_iteration(tractogram1)) # Implicitly set _nb_streamlines. # Create a copy of the tractogram. tractogram2 = tractogram1.copy() @@ -466,7 +466,7 @@ def test_lazy_tractogram_copy(self): # Check the data are the equivalent. assert_equal(tractogram1._nb_streamlines, tractogram2._nb_streamlines) - assert_true(isiterable(tractogram2)) + assert_true(check_iteration(tractogram2)) assert_equal(len(tractogram1), len(tractogram2)) assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index ba3f3d3291..f689e53ef8 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -5,7 +5,7 @@ from nibabel.externals.six import BytesIO from nibabel.testing import suppress_warnings, clear_and_catch_warnings -from nibabel.testing import assert_arrays_equal, isiterable +from nibabel.testing import assert_arrays_equal, check_iteration from nose.tools import assert_equal, assert_raises, assert_true from .test_tractogram import assert_tractogram_equal diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py index d29aea2ca3..8fdbc74f08 100644 --- a/nibabel/testing/__init__.py +++ b/nibabel/testing/__init__.py @@ -11,6 +11,7 @@ import sys import warnings +import collections from os.path import dirname, abspath, join as pjoin from nibabel.externals.six.moves import zip_longest @@ -59,7 +60,8 @@ def assert_allclose_safely(a, b, match_nans=True): assert_true(np.allclose(a, b)) -def isiterable(iterable): +def check_iteration(iterable): + """ Checks that an object can be iterated through without errors. 
""" try: for _ in iterable: pass From 753e18127ad2ceb8ea45b9a38b21fbe3429e3a3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 1 Dec 2015 14:55:53 -0500 Subject: [PATCH 50/54] Support upper case file extension --- nibabel/streamlines/__init__.py | 7 +-- nibabel/streamlines/tests/test_streamlines.py | 43 +++++++++++-------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index abf9ca27f3..6cb4bdd87a 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -1,3 +1,6 @@ +import os +from ..externals.six import string_types + from .header import TractogramHeader from .compact_list import CompactList from .tractogram import Tractogram, LazyTractogram @@ -54,11 +57,9 @@ def detect_format(fileobj): except IOError: pass - import os - from ..externals.six import string_types if isinstance(fileobj, string_types): _, ext = os.path.splitext(fileobj) - return FORMATS.get(ext, None) + return FORMATS.get(ext.lower()) return None diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index f76e8ab581..c21c688989 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -28,21 +28,21 @@ def test_is_supported(): assert_false(nib.streamlines.is_supported("")) # Valid file without extension - for tractogram_file in nib.streamlines.FORMATS.values(): + for tfile_cls in nib.streamlines.FORMATS.values(): f = BytesIO() - f.write(tractogram_file.get_magic_number()) + f.write(tfile_cls.get_magic_number()) f.seek(0, os.SEEK_SET) assert_true(nib.streamlines.is_supported(f)) # Wrong extension but right magic number - for tractogram_file in nib.streamlines.FORMATS.values(): + for tfile_cls in nib.streamlines.FORMATS.values(): with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f: - f.write(tractogram_file.get_magic_number()) + f.write(tfile_cls.get_magic_number()) f.seek(0, os.SEEK_SET) assert_true(nib.streamlines.is_supported(f)) # Good extension but wrong magic number - for ext, tractogram_file in nib.streamlines.FORMATS.items(): + for ext, tfile_cls in nib.streamlines.FORMATS.items(): with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f: f.write(b"pass") f.seek(0, os.SEEK_SET) @@ -53,7 +53,7 @@ def test_is_supported(): assert_false(nib.streamlines.is_supported(f)) # Good extension, string only - for ext, tractogram_file in nib.streamlines.FORMATS.items(): + for ext, tfile_cls in nib.streamlines.FORMATS.items(): f = "my_tractogram" + ext assert_true(nib.streamlines.is_supported(f)) @@ -61,38 +61,43 @@ def test_is_supported(): def test_detect_format(): # Emtpy file/string f = BytesIO() - assert_equal(nib.streamlines.detect_format(f), None) - assert_equal(nib.streamlines.detect_format(""), None) + assert_true(nib.streamlines.detect_format(f) is None) + assert_true(nib.streamlines.detect_format("") is None) # Valid file without extension - for tractogram_file in nib.streamlines.FORMATS.values(): + for tfile_cls in nib.streamlines.FORMATS.values(): f = BytesIO() - f.write(tractogram_file.get_magic_number()) + f.write(tfile_cls.get_magic_number()) f.seek(0, os.SEEK_SET) - assert_equal(nib.streamlines.detect_format(f), tractogram_file) + assert_true(nib.streamlines.detect_format(f) is tfile_cls) # Wrong extension but right magic number - for tractogram_file in nib.streamlines.FORMATS.values(): + for tfile_cls in nib.streamlines.FORMATS.values(): with 
tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f: - f.write(tractogram_file.get_magic_number()) + f.write(tfile_cls.get_magic_number()) f.seek(0, os.SEEK_SET) - assert_equal(nib.streamlines.detect_format(f), tractogram_file) + assert_true(nib.streamlines.detect_format(f) is tfile_cls) # Good extension but wrong magic number - for ext, tractogram_file in nib.streamlines.FORMATS.items(): + for ext, tfile_cls in nib.streamlines.FORMATS.items(): with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f: f.write(b"pass") f.seek(0, os.SEEK_SET) - assert_equal(nib.streamlines.detect_format(f), None) + assert_true(nib.streamlines.detect_format(f) is None) # Wrong extension, string only f = "my_tractogram.asd" - assert_equal(nib.streamlines.detect_format(f), None) + assert_true(nib.streamlines.detect_format(f) is None) # Good extension, string only - for ext, tractogram_file in nib.streamlines.FORMATS.items(): + for ext, tfile_cls in nib.streamlines.FORMATS.items(): f = "my_tractogram" + ext - assert_equal(nib.streamlines.detect_format(f), tractogram_file) + assert_equal(nib.streamlines.detect_format(f), tfile_cls) + + # Extension should not be case-sensitive. + for ext, tfile_cls in nib.streamlines.FORMATS.items(): + f = "my_tractogram" + ext.upper() + assert_true(nib.streamlines.detect_format(f) is tfile_cls) class TestLoadSave(unittest.TestCase): From 2b9d57bc4390b4d8b8fd8ac988f6cdd3d5fb64bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 1 Dec 2015 15:31:00 -0500 Subject: [PATCH 51/54] Fixed typo --- nibabel/streamlines/compact_list.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index ffe2e9be9d..889c48faea 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -148,7 +148,7 @@ def __getitem__(self, idx): Returns ------- ndarray object(s) - When `idx` is a int, returns a single ndarray. + When `idx` is an int, returns a single ndarray. When `idx` is either a slice or a list, returns a list of ndarrays. """ if isinstance(idx, int) or isinstance(idx, np.integer): @@ -178,7 +178,7 @@ def __getitem__(self, idx): for i, take_it in enumerate(idx) if take_it] return clist - raise TypeError("Index must be a int or a slice! Not " + str(type(idx))) + raise TypeError("Index must be an int or a slice! 
Not " + str(type(idx))) def __iter__(self): if len(self._lengths) != len(self._offsets): From e180d3a9d05501dfb14d4db950f11ba7c4f6bfeb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 1 Dec 2015 15:52:54 -0500 Subject: [PATCH 52/54] Use isinstance instead of type() whenever possible --- nibabel/streamlines/compact_list.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index 889c48faea..35ca7c98ae 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -155,14 +155,14 @@ def __getitem__(self, idx): start = self._offsets[idx] return self._data[start:start+self._lengths[idx]] - elif type(idx) is slice: + elif isinstance(idx, slice): clist = CompactList() clist._data = self._data clist._offsets = self._offsets[idx] clist._lengths = self._lengths[idx] return clist - elif type(idx) is list: + elif isinstance(idx, list): clist = CompactList() clist._data = self._data clist._offsets = [self._offsets[i] for i in idx] @@ -178,7 +178,8 @@ def __getitem__(self, idx): for i, take_it in enumerate(idx) if take_it] return clist - raise TypeError("Index must be an int or a slice! Not " + str(type(idx))) + raise TypeError("Index must be either an int, a slice, a list of int" + " or a ndarray of bool! Not " + str(type(idx))) def __iter__(self): if len(self._lengths) != len(self._offsets): From 876a6f88b7da0d30c7d6a17ae0561f3a8fce3607 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 1 Dec 2015 20:16:09 -0500 Subject: [PATCH 53/54] Added the module streamlines to nibabel --- nibabel/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nibabel/__init__.py b/nibabel/__init__.py index 779f6e8587..36310ca019 100644 --- a/nibabel/__init__.py +++ b/nibabel/__init__.py @@ -64,6 +64,7 @@ from .imageclasses import class_map, ext_map, all_image_classes from . import trackvis from . import mriutils +from . 
import streamlines # be friendly on systems with ancient numpy -- no tests, but at least # importable From 8667fe9ad88255abc42e6ed252d7a539a66d07ca Mon Sep 17 00:00:00 2001 From: Marc-Alexandre Cote Date: Wed, 2 Dec 2015 17:58:29 -0500 Subject: [PATCH 54/54] BF: CompactList.extend with a sliced CompactList was not doing the right thing --- nibabel/streamlines/compact_list.py | 19 ++++++--- .../streamlines/tests/test_compact_list.py | 41 +++++++++++++------ 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index 35ca7c98ae..8fb761f312 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -107,16 +107,23 @@ def extend(self, elements): elem = np.asarray(elements[0]) self._data = np.zeros((0, elem.shape[1]), dtype=elem.dtype) + next_offset = self._data.shape[0] + if isinstance(elements, CompactList): - self._data = np.concatenate([self._data, elements._data], axis=0) - lengths = elements._lengths + self._data.resize((self._data.shape[0]+sum(elements._lengths), + self._data.shape[1])) + + for offset, length in zip(elements._offsets, elements._lengths): + self._offsets.append(next_offset) + self._lengths.append(length) + self._data[next_offset:next_offset+length] = elements._data[offset:offset+length] + next_offset += length + else: self._data = np.concatenate([self._data] + list(elements), axis=0) lengths = list(map(len, elements)) - - idx = self._offsets[-1] + self._lengths[-1] if len(self) > 0 else 0 - self._lengths.extend(lengths) - self._offsets.extend(np.cumsum([idx] + lengths).tolist()[:-1]) + self._lengths.extend(lengths) + self._offsets.extend(np.cumsum([next_offset] + lengths).tolist()[:-1]) def copy(self): """ Creates a copy of this ``CompactList`` object. """ diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py index 7a1c25000f..188102ce85 100644 --- a/nibabel/streamlines/tests/test_compact_list.py +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -4,6 +4,7 @@ import numpy as np from nose.tools import assert_equal, assert_raises, assert_true +from nibabel.testing import assert_arrays_equal from numpy.testing import assert_array_equal from nibabel.externals.six.moves import zip, zip_longest @@ -170,7 +171,7 @@ def test_compactlist_extend(self): rng = np.random.RandomState(1234) shape = self.clist.shape - new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(5)] + new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(10)] lengths = list(map(len, new_data)) clist.extend(new_data) assert_equal(len(clist), len(self.clist)+len(new_data)) @@ -183,22 +184,38 @@ def test_compactlist_extend(self): # Extend with another `CompactList` object. clist = self.clist.copy() - new_data = CompactList(new_data) - clist.extend(new_data) - assert_equal(len(clist), len(self.clist)+len(new_data)) - assert_array_equal(clist._offsets[-len(new_data):], + new_clist = CompactList(new_data) + clist.extend(new_clist) + assert_equal(len(clist), len(self.clist)+len(new_clist)) + assert_array_equal(clist._offsets[-len(new_clist):], len(self.clist._data) + np.cumsum([0] + lengths[:-1])) - assert_equal(clist._lengths[-len(new_data):], lengths) - assert_array_equal(clist._data[-sum(lengths):], new_data._data) + assert_equal(clist._lengths[-len(new_clist):], lengths) + assert_array_equal(clist._data[-sum(lengths):], new_clist._data) + + # Extend with another `CompactList` object that is a view (e.g. been sliced). 
+ # Need to make sure we extend only the data we need. + clist = self.clist.copy() + new_clist = CompactList(new_data)[::2] + clist.extend(new_clist) + assert_equal(len(clist), len(self.clist)+len(new_clist)) + assert_equal(len(clist._data), len(self.clist._data)+sum(new_clist._lengths)) + assert_array_equal(clist._offsets[-len(new_clist):], + len(self.clist._data) + np.cumsum([0] + new_clist._lengths[:-1])) + + assert_equal(clist._lengths[-len(new_clist):], lengths[::2]) + assert_array_equal(clist._data[-sum(new_clist._lengths):], new_clist.copy()._data) + assert_arrays_equal(clist[-len(new_clist):], new_clist) # Test extending an empty CompactList clist = CompactList() - clist.extend(new_data) - assert_equal(len(clist), len(new_data)) - assert_array_equal(clist._offsets, new_data._offsets) - assert_array_equal(clist._lengths, new_data._lengths) - assert_array_equal(clist._data, new_data._data) + new_clist = CompactList(new_data) + clist.extend(new_clist) + assert_equal(len(clist), len(new_clist)) + assert_array_equal(clist._offsets, new_clist._offsets) + assert_array_equal(clist._lengths, new_clist._lengths) + assert_array_equal(clist._data, new_clist._data) + def test_compactlist_getitem(self): # Get one item
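
To close the series, the gist of this last bug fix: a sliced CompactList is a
view whose `_offsets` and `_lengths` index into the parent's full `_data`
buffer, so extend() must copy only the rows the view references instead of
concatenating `view._data` wholesale. A sketch of that logic as a standalone
helper (hypothetical function, assuming the private attributes used above):

    import numpy as np

    def extend_from_view(data, offsets, lengths, view):
        # Copy only the rows each (offset, length) pair points at;
        # `view._data` is the parent's whole buffer and must not be
        # appended as-is.
        next_offset = len(data)
        chunks = [data]
        for offset, length in zip(view._offsets, view._lengths):
            chunks.append(view._data[offset:offset + length])
            offsets.append(next_offset)
            lengths.append(length)
            next_offset += length
        return np.concatenate(chunks, axis=0), offsets, lengths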