From 9ee5d669306ec1b6fe88222ee6a118e89aea8e45 Mon Sep 17 00:00:00 2001 From: Marc-Alexandre Cote Date: Thu, 20 Feb 2014 19:40:09 -0500 Subject: [PATCH 001/135] First draft --- nibabel/streamlines/__init__.py | 24 ++ nibabel/streamlines/base_format.py | 39 +++ nibabel/streamlines/header.py | 15 + nibabel/streamlines/tests/__init__.py | 0 nibabel/streamlines/tests/test_trk.py | 56 ++++ nibabel/streamlines/trk.py | 464 ++++++++++++++++++++++++++ nibabel/streamlines/utils.py | 47 +++ 7 files changed, 645 insertions(+) create mode 100644 nibabel/streamlines/__init__.py create mode 100644 nibabel/streamlines/base_format.py create mode 100644 nibabel/streamlines/header.py create mode 100644 nibabel/streamlines/tests/__init__.py create mode 100644 nibabel/streamlines/tests/test_trk.py create mode 100644 nibabel/streamlines/trk.py create mode 100644 nibabel/streamlines/utils.py diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py new file mode 100644 index 0000000000..67a9ec56ff --- /dev/null +++ b/nibabel/streamlines/__init__.py @@ -0,0 +1,24 @@ +from nibabel.openers import Opener + +from nibabel.streamlines.utils import detect_format + + +def load(fileobj): + ''' Load a file of streamlines, return instance associated to file format + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the beginning + of the streamlines file's header) + + Returns + ------- + obj : instance of ``StreamlineFile`` + Returns an instance of a ``StreamlineFile`` subclass corresponding to + the format of the streamlines file ``fileobj``. + ''' + fileobj = Opener(fileobj) + streamlines_file = detect_format(fileobj) + return streamlines_file.load(fileobj) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py new file mode 100644 index 0000000000..eb10baa428 --- /dev/null +++ b/nibabel/streamlines/base_format.py @@ -0,0 +1,39 @@ + +class StreamlineFile: + @staticmethod + def get_magic_number(): + raise NotImplementedError() + + @staticmethod + def is_correct_format(cls, fileobj): + raise NotImplementedError() + + def get_header(self): + raise NotImplementedError() + + def get_streamlines(self, as_generator=False): + raise NotImplementedError() + + def get_scalars(self, as_generator=False): + raise NotImplementedError() + + def get_properties(self, as_generator=False): + raise NotImplementedError() + + @classmethod + def load(cls, fileobj): + raise NotImplementedError() + + def save(self, filename): + raise NotImplementedError() + + def __iter__(self): + raise NotImplementedError() + + +class DynamicStreamlineFile(StreamlineFile): + def append(self, streamlines): + raise NotImplementedError() + + def __iadd__(self, streamlines): + return self.append(streamlines) \ No newline at end of file diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py new file mode 100644 index 0000000000..b12c42d8d4 --- /dev/null +++ b/nibabel/streamlines/header.py @@ -0,0 +1,15 @@ +class Field: + NB_STREAMLINES = "nb_streamlines" + STEP_SIZE = "step_size" + METHOD = "method" + NB_SCALARS_PER_POINT = "nb_scalars_per_point" + NB_PROPERTIES_PER_STREAMLINE = "nb_properties_per_streamline" + NB_POINTS = "nb_points" + VOXEL_SIZES = "voxel_sizes" + DIMENSIONS = "dimensions" + MAGIC_NUMBER = "magic_number" + ORIGIN = "origin" + VOXEL_TO_WORLD = "voxel_to_world" + VOXEL_ORDER = "voxel_order" + WORLD_ORDER = "world_order" + ENDIAN = "endian" \ No newline 
at end of file diff --git a/nibabel/streamlines/tests/__init__.py b/nibabel/streamlines/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py new file mode 100644 index 0000000000..0f01b03dfe --- /dev/null +++ b/nibabel/streamlines/tests/test_trk.py @@ -0,0 +1,56 @@ +# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- +# vi: set ft=python sts=4 ts=4 sw=4 et: + +from pdb import set_trace as dbg + +from os.path import join as pjoin, dirname + +from numpy.testing import (assert_equal, + assert_almost_equal, + assert_array_equal, + assert_array_almost_equal, + assert_raises) + +DATA_PATH = pjoin(dirname(__file__), 'data') + +import nibabel as nib +from nibabel.streamlines.header import Field + + +def test_load_file(): + # Test loading empty file + # empty_file = pjoin(DATA_PATH, "empty.trk") + # empty_trk = nib.streamlines.load(empty_file) + + # hdr = empty_trk.get_header() + # points = empty_trk.get_points(as_generator=False) + # scalars = empty_trk.get_scalars(as_generator=False) + # properties = empty_trk.get_properties(as_generator=False) + + # assert_equal(hdr[Field.NB_STREAMLINES], 0) + # assert_equal(len(points), 0) + # assert_equal(len(scalars), 0) + # assert_equal(len(properties), 0) + + # for i in empty_trk: pass # Check if we can iterate through the streamlines. + + # Test loading non-empty file + trk_file = pjoin(DATA_PATH, "uncinate.trk") + trk = nib.streamlines.load(trk_file) + + hdr = trk.get_header() + points = trk.get_points(as_generator=False) + 1/0 + scalars = trk.get_scalars(as_generator=False) + properties = trk.get_properties(as_generator=False) + + assert_equal(hdr[Field.NB_STREAMLINES] > 0, True) + assert_equal(len(points) > 0, True) + #assert_equal(len(scalars), 0) + #assert_equal(len(properties), 0) + + for i in trk: pass # Check if we can iterate through the streamlines. + + +if __name__ == "__main__": + test_load_file() \ No newline at end of file diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py new file mode 100644 index 0000000000..6a02397cf6 --- /dev/null +++ b/nibabel/streamlines/trk.py @@ -0,0 +1,464 @@ +# Documentation available here: +# http://www.trackvis.org/docs/?subsect=fileformat +from pdb import set_trace as dbg + +import os +import warnings +import numpy as np +from numpy.lib.recfunctions import append_fields + +from nibabel.openers import Opener +from nibabel.volumeutils import (native_code, swapped_code, endian_codes) + +from nibabel.streamlines.base_format import DynamicStreamlineFile +from nibabel.streamlines.header import Field + +# Definition of trackvis header structure. 
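+# Each entry below is a (name, numpy dtype code[, shape]) triple; numpy
+# packs the fields in declaration order with no padding, so the full
+# header_2_dtype defined further down comes to the 1000 bytes expected in
+# 'hdr_size'. A small sanity check (illustrative, not part of the spec):
+#
+#     >>> np.dtype([('magic_number', 'S6'), ('dimensions', 'h', 3)]).itemsize
+#     12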
+# See http://www.trackvis.org/docs/?subsect=fileformat +# See http://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html +header_1_dtd = [ + (Field.MAGIC_NUMBER, 'S6'), + (Field.DIMENSIONS, 'h', 3), + (Field.VOXEL_SIZES, 'f4', 3), + (Field.ORIGIN, 'f4', 3), + (Field.NB_SCALARS_PER_POINT, 'h'), + ('scalar_name', 'S20', 10), + (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), + ('property_name', 'S20', 10), + ('reserved', 'S508'), + (Field.VOXEL_ORDER, 'S4'), + ('pad2', 'S4'), + ('image_orientation_patient', 'f4', 6), + ('pad1', 'S2'), + ('invert_x', 'S1'), + ('invert_y', 'S1'), + ('invert_z', 'S1'), + ('swap_xy', 'S1'), + ('swap_yz', 'S1'), + ('swap_zx', 'S1'), + (Field.NB_STREAMLINES, 'i4'), + ('version', 'i4'), + ('hdr_size', 'i4'), + ] + +# Version 2 adds a 4x4 matrix giving the affine transformtation going +# from voxel coordinates in the referenced 3D voxel matrix, to xyz +# coordinates (axes L->R, P->A, I->S). IF (0 based) value [3, 3] from +# this matrix is 0, this means the matrix is not recorded. +header_2_dtd = [ + (Field.MAGIC_NUMBER, 'S6'), + (Field.DIMENSIONS, 'h', 3), + (Field.VOXEL_SIZES, 'f4', 3), + (Field.ORIGIN, 'f4', 3), + (Field.NB_SCALARS_PER_POINT, 'h'), + ('scalar_name', 'S20', 10), + (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), + ('property_name', 'S20', 10), + (Field.VOXEL_TO_WORLD, 'f4', (4,4)), # new field for version 2 + ('reserved', 'S444'), + (Field.VOXEL_ORDER, 'S4'), + ('pad2', 'S4'), + ('image_orientation_patient', 'f4', 6), + ('pad1', 'S2'), + ('invert_x', 'S1'), + ('invert_y', 'S1'), + ('invert_z', 'S1'), + ('swap_xy', 'S1'), + ('swap_yz', 'S1'), + ('swap_zx', 'S1'), + (Field.NB_STREAMLINES, 'i4'), + ('version', 'i4'), + ('hdr_size', 'i4'), + ] + +# Full header numpy dtypes +header_1_dtype = np.dtype(header_1_dtd) +header_2_dtype = np.dtype(header_2_dtd) + + +class HeaderError(Exception): + pass + + +class DataError(Exception): + pass + + +class TrkFile(DynamicStreamlineFile): + MAGIC_NUMBER = "TRACK" + OFFSET = 1000 + + def __init__(self, hdr, streamlines, scalars, properties): + self.filename = None + + self.hdr = hdr + self.streamlines = streamlines + self.scalars = scalars + self.properties = properties + + ##### + # Static Methods + ### + @classmethod + def get_magic_number(cls): + ''' Return TRK's magic number ''' + return cls.MAGIC_NUMBER + + @classmethod + def is_correct_format(cls, fileobj): + ''' Check if the file is in TRK format. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to TRK file (and ready to read from the beginning + of the TRK header data) + + Returns + ------- + is_correct_format : boolean + Returns True if `fileobj` is in TRK format, False otherwise. 
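+
+        Examples
+        --------
+        A minimal sketch; the file names below are hypothetical:
+
+        >>> TrkFile.is_correct_format('tracks.trk')   # doctest: +SKIP
+        True
+        >>> TrkFile.is_correct_format('anat.nii.gz')  # doctest: +SKIP
+        False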
+        '''
+        with Opener(fileobj) as fileobj:
+            magic_number = fileobj.read(5)
+            fileobj.seek(-5, os.SEEK_CUR)
+            return magic_number == cls.MAGIC_NUMBER
+
+    @classmethod
+    def load(cls, fileobj):
+        hdr = {}
+        pos_header = 0
+        pos_data = 0
+
+        with Opener(fileobj) as fileobj:
+            pos_header = fileobj.tell()
+
+            #####
+            # Read header
+            ###
+            hdr_str = fileobj.read(header_2_dtype.itemsize)
+            hdr = np.fromstring(string=hdr_str, dtype=header_2_dtype)
+
+            # Check endianness while the header is still a ndarray: a valid
+            # TRK header always has hdr_size == 1000.
+            endianness = native_code
+            if hdr['hdr_size'] != 1000:
+                endianness = swapped_code
+                hdr = hdr.newbyteorder()
+                if hdr['hdr_size'] != 1000:
+                    raise HeaderError('Invalid hdr_size of {0}'.format(hdr['hdr_size']))
+
+            if hdr['version'] == 1:
+                hdr = np.fromstring(string=hdr_str, dtype=header_1_dtype)
+                if endianness == swapped_code:
+                    hdr = hdr.newbyteorder()
+            elif hdr['version'] == 2:
+                pass  # Nothing more to do here
+            else:
+                raise HeaderError('NiBabel only supports versions 1 and 2.')
+
+            # Make header a dictionary instead of a ndarray.
+            hdr = dict(zip(hdr.dtype.names, hdr[0]))
+            hdr[Field.ENDIAN] = endianness
+
+            # Add more header fields implied by trk format.
+            hdr[Field.WORLD_ORDER] = "RAS"
+
+            pos_data = fileobj.tell()
+
+            i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4")
+            f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4")
+
+            nb_streamlines = 0
+
+            # Either verify that the number of streamlines specified in the
+            # header is correct, or count the actual number of streamlines
+            # in case it was not specified in the header.
+            while True:
+                # Read number of points of the streamline
+                buf = fileobj.read(i4_dtype.itemsize)
+
+                if buf == b'':
+                    break  # EOF
+
+                nb_pts = np.fromstring(buf, dtype=i4_dtype, count=1)[0]
+
+                bytes_to_skip = nb_pts * 3  # x, y, z coordinates
+                bytes_to_skip += nb_pts * hdr[Field.NB_SCALARS_PER_POINT]
+                bytes_to_skip += hdr[Field.NB_PROPERTIES_PER_STREAMLINE]
+
+                # Seek to the next streamline in the file.
+                fileobj.seek(bytes_to_skip * f4_dtype.itemsize, os.SEEK_CUR)
+
+                nb_streamlines += 1
+
+        if hdr[Field.NB_STREAMLINES] != nb_streamlines:
+            warnings.warn(('The number of streamlines specified in the header '
+                           '({0}) does not match the actual number of '
+                           'streamlines contained in this file ({1}). '
+                           'The latter will be used.'
+                           ).format(hdr[Field.NB_STREAMLINES], nb_streamlines))
+
+        hdr[Field.NB_STREAMLINES] = nb_streamlines
+
+        trk_file = cls(hdr, [], [], [])
+        trk_file.pos_header = pos_header
+        trk_file.pos_data = pos_data
+        # Keep a handle for lazy reading in `get_points` (`Opener` leaves
+        # `fileobj` open on exit when it did not open the file itself).
+        trk_file.fileobj = fileobj
+
+        return trk_file
+
+    def get_header(self):
+        return self.hdr
+
+    def get_points(self, as_generator=False):
+        self.fileobj.seek(self.pos_data, os.SEEK_SET)
+        pos = self.pos_data
+
+        i4_dtype = np.dtype(self.hdr[Field.ENDIAN] + "i4")
+        f4_dtype = np.dtype(self.hdr[Field.ENDIAN] + "f4")
+
+        for i in range(self.hdr[Field.NB_STREAMLINES]):
+            # Read number of points of the streamline
+            nb_pts = np.fromstring(self.fileobj.read(i4_dtype.itemsize),
+                                   dtype=i4_dtype, count=1)[0]
+
+            # Read the (x, y, z) coordinates of the streamline as a flat
+            # array of float32, then reshape to one row per point.
+            pts = np.fromstring(self.fileobj.read(nb_pts * 3 * f4_dtype.itemsize),
+                                dtype=f4_dtype,
+                                count=nb_pts * 3).reshape((nb_pts, 3))
+
+            pos = self.fileobj.tell()
+            yield pts
+            self.fileobj.seek(pos, os.SEEK_SET)
+
+            bytes_to_skip = nb_pts * self.hdr[Field.NB_SCALARS_PER_POINT]
+            bytes_to_skip += self.hdr[Field.NB_PROPERTIES_PER_STREAMLINE]
+
+            # Seek to the next streamline in the file.
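+            # (Worked example, assuming nb_pts = 10, 2 scalars per point and
+            # 3 properties per streamline: (10 * 2 + 3) * 4 = 92 bytes of
+            # scalar/property data separate this streamline from the next
+            # point count.)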
+ self.fileobj.seek(bytes_to_skip * f4_dtype.itemsize, os.SEEK_CUR) + + ##### + # Methods + ### + + + + +# import os +# import logging +# import numpy as np + +# from tractconverter.formats.header import Header as H + + +# def readBinaryBytes(f, nbBytes, dtype): +# buff = f.read(nbBytes * dtype.itemsize) +# return np.frombuffer(buff, dtype=dtype) + + +# class TRK: +# # self.hdr +# # self.filename +# # self.hdr[H.ENDIAN] +# # self.FIBER_DELIMITER +# # self.END_DELIMITER + +# @staticmethod +# def create(filename, hdr, anatFile=None): +# f = open(filename, 'wb') +# f.write(TRK.MAGIC_NUMBER + "\n") +# f.close() + +# trk = TRK(filename, load=False) +# trk.hdr = hdr +# trk.writeHeader() + +# return trk + +# ##### +# # Methods +# ### +# def __init__(self, filename, anatFile=None, load=True): +# if not TRK._check(filename): +# raise NameError("Not a TRK file.") + +# self.filename = filename +# self.hdr = {} +# if load: +# self._load() + +# def _load(self): +# f = open(self.filename, 'rb') + +# ##### +# # Read header +# ### +# self.hdr[H.MAGIC_NUMBER] = f.read(6) +# self.hdr[H.DIMENSIONS] = np.frombuffer(f.read(6), dtype='i4') +# self.hdr["version"] = self.hdr["version"].astype('>i4') +# self.hdr["hdr_size"] = self.hdr["hdr_size"].astype('>i4') + +# nb_fibers = 0 +# self.hdr[H.NB_POINTS] = 0 + +# #Either verify the number of streamlines specified in the header is correct or +# # count the actual number of streamlines in case it is not specified in the header. +# remainingBytes = os.path.getsize(self.filename) - self.OFFSET +# while remainingBytes > 0: +# # Read points +# nbPoints = readBinaryBytes(f, 1, np.dtype(self.hdr[H.ENDIAN] + "i4"))[0] +# self.hdr[H.NB_POINTS] += nbPoints +# # This seek is used to go to the next points number indication in the file. +# f.seek((nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]) +# + self.hdr[H.NB_PROPERTIES_PER_STREAMLINE]) * 4, 1) # Relative seek +# remainingBytes -= (nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]) +# + self.hdr[H.NB_PROPERTIES_PER_STREAMLINE]) * 4 + 4 +# nb_fibers += 1 + +# if self.hdr[H.NB_FIBERS] != nb_fibers: +# logging.warn('The number of streamlines specified in header ({1}) does not match ' + +# 'the actual number of streamlines contained in this file ({1}). ' + +# 'The latter will be used.'.format(self.hdr[H.NB_FIBERS], nb_fibers)) + +# self.hdr[H.NB_FIBERS] = nb_fibers + +# f.close() + +# def writeHeader(self): +# # Get the voxel size and format it as an array. +# voxel_sizes = np.asarray(self.hdr.get(H.VOXEL_SIZES, (1.0, 1.0, 1.0)), dtype=' 0: +# # Read points +# nbPoints = readBinaryBytes(f, 1, np.dtype(self.hdr[H.ENDIAN] + "i4"))[0] +# ptsAndScalars = readBinaryBytes(f, +# nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]), +# np.dtype(self.hdr[H.ENDIAN] + "f4")) + +# newShape = [-1, 3 + self.hdr[H.NB_SCALARS_PER_POINT]] +# ptsAndScalars = ptsAndScalars.reshape(newShape) + +# pointsWithoutScalars = ptsAndScalars[:, 0:3] +# yield pointsWithoutScalars + +# # For now, we do not process the tract properties, so just skip over them. 
+#             remainingBytes -= nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]) * 4 + 4
+#             remainingBytes -= self.hdr[H.NB_PROPERTIES_PER_STREAMLINE] * 4
+#             cpt += 1
+
+#         f.close()
+
+#     def __str__(self):
+#         text = ""
+#         text += "MAGIC NUMBER: {0}".format(self.hdr[H.MAGIC_NUMBER])
+#         text += "v.{0}".format(self.hdr['version'])
+#         text += "dim: {0}".format(self.hdr[H.DIMENSIONS])
+#         text += "voxel_sizes: {0}".format(self.hdr[H.VOXEL_SIZES])
+#         text += "origin: {0}".format(self.hdr[H.ORIGIN])
+#         text += "nb_scalars: {0}".format(self.hdr[H.NB_SCALARS_PER_POINT])
+#         text += "scalar_name:\n {0}".format("\n".join(self.hdr['scalar_name']))
+#         text += "nb_properties: {0}".format(self.hdr[H.NB_PROPERTIES_PER_STREAMLINE])
+#         text += "property_name:\n {0}".format("\n".join(self.hdr['property_name']))
+#         text += "vox_to_world: {0}".format(self.hdr[H.VOXEL_TO_WORLD])
+#         text += "world_order: {0}".format(self.hdr[H.WORLD_ORDER])
+#         text += "voxel_order: {0}".format(self.hdr[H.VOXEL_ORDER])
+#         text += "image_orientation_patient: {0}".format(self.hdr['image_orientation_patient'])
+#         text += "pad1: {0}".format(self.hdr['pad1'])
+#         text += "pad2: {0}".format(self.hdr['pad2'])
+#         text += "invert_x: {0}".format(self.hdr['invert_x'])
+#         text += "invert_y: {0}".format(self.hdr['invert_y'])
+#         text += "invert_z: {0}".format(self.hdr['invert_z'])
+#         text += "swap_xy: {0}".format(self.hdr['swap_xy'])
+#         text += "swap_yz: {0}".format(self.hdr['swap_yz'])
+#         text += "swap_zx: {0}".format(self.hdr['swap_zx'])
+#         text += "n_count: {0}".format(self.hdr[H.NB_FIBERS])
+#         text += "hdr_size: {0}".format(self.hdr['hdr_size'])
+#         text += "endianness: {0}".format(self.hdr[H.ENDIAN])
+
+#         return text
diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py
new file mode 100644
index 0000000000..511f27fe01
--- /dev/null
+++ b/nibabel/streamlines/utils.py
@@ -0,0 +1,47 @@
+import os
+
+import nibabel as nib
+
+from nibabel.streamlines.trk import TrkFile
+#from nibabel.streamlines.tck import TckFile
+#from nibabel.streamlines.vtk import VtkFile
+#from nibabel.streamlines.fib import FibFile
+
+# Supported formats
+FORMATS = {"trk": TrkFile,
+           #"tck": TckFile,
+           #"vtk": VtkFile,
+           #"fib": FibFile,
+           }
+
+def is_supported(fileobj):
+    return detect_format(fileobj) is not None
+
+
+def detect_format(fileobj):
+    for format in FORMATS.values():
+        if format.is_correct_format(fileobj):
+            return format
+
+    if isinstance(fileobj, basestring):
+        _, ext = os.path.splitext(fileobj)
+        # `splitext` keeps the leading dot, but FORMATS is keyed without it.
+        return FORMATS.get(ext[1:].lower(), None)
+
+    return None
+
+
+def convert(in_fileobj, out_filename):
+    in_fileobj = nib.streamlines.load(in_fileobj)
+    out_format = detect_format(out_filename)
+
+    hdr = in_fileobj.get_header()
+    points = in_fileobj.get_points(as_generator=True)
+    scalars = in_fileobj.get_scalars(as_generator=True)
+    properties = in_fileobj.get_properties(as_generator=True)
+
+    out_fileobj = out_format(hdr, points, scalars, properties)
+    out_fileobj.save(out_filename)
+
+
+def change_space(streamline_file, new_point_space):
+    pass
\ No newline at end of file

From 9cf22a55651a3b65ea658cad38bc36f782bd8ebf Mon Sep 17 00:00:00 2001
From: Marc-Alexandre Cote
Date: Fri, 18 Jul 2014 22:47:31 -0400
Subject: [PATCH 002/135] A working prototype of the new streamlines API

---
 nibabel/__init__.py                           |   5 +-
 nibabel/benchmarks/bench_streamlines.py       | 127 +++
 nibabel/externals/six.py                      |   1 +
 nibabel/streamlines/__init__.py               |  27 +-
 nibabel/streamlines/base_format.py            | 186 ++++-
 nibabel/streamlines/header.py                 |  44 +-
 nibabel/streamlines/tests/data/complex.trk    | Bin 
0 -> 1228 bytes nibabel/streamlines/tests/data/empty.trk | Bin 0 -> 1000 bytes nibabel/streamlines/tests/data/simple.trk | Bin 0 -> 1108 bytes nibabel/streamlines/tests/test_base_format.py | 117 +++ nibabel/streamlines/tests/test_trk.py | 365 +++++++-- nibabel/streamlines/tests/test_utils.py | 323 ++++++++ nibabel/streamlines/trk.py | 750 +++++++++--------- nibabel/streamlines/utils.py | 126 ++- setup.py | 2 + 15 files changed, 1583 insertions(+), 490 deletions(-) create mode 100644 nibabel/benchmarks/bench_streamlines.py create mode 100644 nibabel/streamlines/tests/data/complex.trk create mode 100644 nibabel/streamlines/tests/data/empty.trk create mode 100644 nibabel/streamlines/tests/data/simple.trk create mode 100644 nibabel/streamlines/tests/test_base_format.py create mode 100644 nibabel/streamlines/tests/test_utils.py diff --git a/nibabel/__init__.py b/nibabel/__init__.py index 4d8791d7d9..8cb1d95e2e 100644 --- a/nibabel/__init__.py +++ b/nibabel/__init__.py @@ -23,9 +23,9 @@ img3 = nib.load('spm_file.img') data = img1.get_data() - affine = img1.affine + affine = img1.get_affine() - print(img1) + print img1 nib.save(img1, 'my_file_copy.nii.gz') @@ -63,6 +63,7 @@ apply_orientation, aff2axcodes) from .imageclasses import class_map, ext_map, all_image_classes from . import trackvis +from .streamlines import Streamlines from . import mriutils from . import viewers diff --git a/nibabel/benchmarks/bench_streamlines.py b/nibabel/benchmarks/bench_streamlines.py new file mode 100644 index 0000000000..3a2e3ab39d --- /dev/null +++ b/nibabel/benchmarks/bench_streamlines.py @@ -0,0 +1,127 @@ +""" Benchmarks for load and save of streamlines + +Run benchmarks with:: + + import nibabel as nib + nib.bench() + +If you have doctests enabled by default in nose (with a noserc file or +environment variable), and you have a numpy version <= 1.6.1, this will also run +the doctests, let's hope they pass. 
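+
+Each benchmark below pits this streamlines API against the existing
+``nibabel.trackvis`` module on synthetic data; timings come from
+``numpy.testing.measure``, which executes the given statement ``repeat``
+times.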
+ +Run this benchmark with: + + nosetests -s --match '(?:^|[\\b_\\.//-])[Bb]ench' /path/to/bench_streamlines.py +""" +from __future__ import division, print_function + +import os +import numpy as np + +from nibabel.externals.six import BytesIO +from nibabel.externals.six.moves import zip + +from nibabel.testing import assert_arrays_equal + +from numpy.testing import assert_array_equal +from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines import TrkFile + +import nibabel as nib +import nibabel.trackvis as tv + +from numpy.testing import measure + + +def bench_load_trk(): + NB_STREAMLINES = 1000 + NB_POINTS = 1000 + points = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)] + repeat = 20 + + trk_file = BytesIO() + trk = list(zip(points, [None]*NB_STREAMLINES, [None]*NB_STREAMLINES)) + tv.write(trk_file, trk) + + mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, lazy_load=False)', repeat) + print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) + + mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file)', repeat) + print("Old: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old)) + print("Speedup of %2f" % (mtime_old/mtime_new)) + + # Points and scalars + scalars = [np.random.rand(NB_POINTS, 10).astype('float32') for i in range(NB_STREAMLINES)] + + trk_file = BytesIO() + trk = list(zip(points, scalars, [None]*NB_STREAMLINES)) + tv.write(trk_file, trk) + + mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, lazy_load=False)', repeat) + print("New: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_new)) + + mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file)', repeat) + print("Old: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_old)) + print("Speedup of %2f" % (mtime_old/mtime_new)) + + +def bench_save_trk(): + NB_STREAMLINES = 100 + NB_POINTS = 1000 + points = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)] + repeat = 10 + + # Only points + streamlines = Streamlines(points) + trk_file_new = BytesIO() + + mtime_new = measure('trk_file_new.seek(0, os.SEEK_SET); TrkFile.save(streamlines, trk_file_new)', repeat) + print("\nNew: Saved %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) + + trk_file_old = BytesIO() + trk = list(zip(points, [None]*NB_STREAMLINES, [None]*NB_STREAMLINES)) + mtime_old = measure('trk_file_old.seek(0, os.SEEK_SET); tv.write(trk_file_old, trk)', repeat) + print("Old: Saved %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old)) + print("Speedup of %2f" % (mtime_old/mtime_new)) + + trk_file_new.seek(0, os.SEEK_SET) + trk_file_old.seek(0, os.SEEK_SET) + streams, hdr = tv.read(trk_file_old) + + for pts, A in zip(points, streams): + assert_array_equal(pts, A[0]) + + trk = nib.streamlines.load(trk_file_new, lazy_load=False) + + assert_arrays_equal(points, trk.points) + + # Points and scalars + scalars = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)] + streamlines = Streamlines(points, scalars=scalars) + trk_file_new = BytesIO() + + mtime_new = measure('trk_file_new.seek(0, os.SEEK_SET); TrkFile.save(streamlines, trk_file_new)', repeat) + print("New: Saved %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_new)) + + trk_file_old = BytesIO() + trk = list(zip(points, scalars, [None]*NB_STREAMLINES)) + mtime_old = measure('trk_file_old.seek(0, os.SEEK_SET); 
tv.write(trk_file_old, trk)', repeat) + print("Old: Saved %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_old)) + print("Speedup of %2f" % (mtime_old/mtime_new)) + + trk_file_new.seek(0, os.SEEK_SET) + trk_file_old.seek(0, os.SEEK_SET) + streams, hdr = tv.read(trk_file_old) + + for pts, scal, A in zip(points, scalars, streams): + assert_array_equal(pts, A[0]) + assert_array_equal(scal, A[1]) + + trk = nib.streamlines.load(trk_file_new, lazy_load=False) + + assert_arrays_equal(points, trk.points) + assert_arrays_equal(scalars, trk.scalars) + + +if __name__ == '__main__': + bench_save_trk() diff --git a/nibabel/externals/six.py b/nibabel/externals/six.py index eae31454ae..b23166effd 100644 --- a/nibabel/externals/six.py +++ b/nibabel/externals/six.py @@ -143,6 +143,7 @@ class _MovedItems(types.ModuleType): MovedAttribute("StringIO", "StringIO", "io"), MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), + MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), MovedModule("builtins", "__builtin__"), MovedModule("configparser", "ConfigParser"), diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 67a9ec56ff..cd7cd56a21 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -1,24 +1,9 @@ -from nibabel.openers import Opener -from nibabel.streamlines.utils import detect_format +from nibabel.streamlines.utils import load, save +from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines.header import Field -def load(fileobj): - ''' Load a file of streamlines, return instance associated to file format - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to a streamlines file (and ready to read from the beginning - of the streamlines file's header) - - Returns - ------- - obj : instance of ``StreamlineFile`` - Returns an instance of a ``StreamlineFile`` subclass corresponding to - the format of the streamlines file ``fileobj``. - ''' - fileobj = Opener(fileobj) - streamlines_file = detect_format(fileobj) - return streamlines_file.load(fileobj) +from nibabel.streamlines.trk import TrkFile +#from nibabel.streamlines.trk import TckFile +#from nibabel.streamlines.trk import VtkFile diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index eb10baa428..43e3b8616e 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -1,39 +1,193 @@ -class StreamlineFile: - @staticmethod - def get_magic_number(): - raise NotImplementedError() +from nibabel.streamlines.header import Field - @staticmethod - def is_correct_format(cls, fileobj): - raise NotImplementedError() +from ..externals.six.moves import zip_longest + + +class HeaderError(Exception): + pass + + +class DataError(Exception): + pass + + +class Streamlines(object): + ''' Class containing information about streamlines. + + Streamlines objects have three main properties: ``points``, ``scalars`` + and ``properties``. Streamlines objects can be iterate over producing + tuple of ``points``, ``scalars`` and ``properties`` for each streamline. + + Parameters + ---------- + points : sequence of ndarray of shape (N, 3) + Sequence of T streamlines. One streamline is an ndarray of shape (N, 3) + where N is the number of points in a streamline. 
+ + scalars : sequence of ndarray of shape (N, M) + Sequence of T ndarrays of shape (N, M) where T is the number of + streamlines defined by ``points``, N is the number of points + for a particular streamline and M is the number of scalars + associated to each point (excluding the three coordinates). + + properties : sequence of ndarray of shape (P,) + Sequence of T ndarrays of shape (P,) where T is the number of + streamlines defined by ``points``, P is the number of properties + associated to each streamlines. + + hdr : dict + Header containing meta information about the streamlines. For a list + of common header's fields to use as keys see `nibabel.streamlines.Field`. + ''' + def __init__(self, points=[], scalars=[], properties=[], hdr={}): + self.hdr = hdr + + self.points = points + self.scalars = scalars + self.properties = properties + self.data = lambda: zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) + + try: + self.length = len(points) + except: + if Field.NB_STREAMLINES in hdr: + self.length = hdr[Field.NB_STREAMLINES] + else: + raise HeaderError(("Neither parameter 'points' nor 'hdr' contain information about" + " number of streamlines. Use key '{0}' to set the number of " + "streamlines in 'hdr'.").format(Field.NB_STREAMLINES)) def get_header(self): - raise NotImplementedError() + return self.hdr + + @property + def points(self): + return self._points() + + @points.setter + def points(self, value): + self._points = value if callable(value) else (lambda: value) + + @property + def scalars(self): + return self._scalars() + + @scalars.setter + def scalars(self, value): + self._scalars = value if callable(value) else lambda: value + + @property + def properties(self): + return self._properties() - def get_streamlines(self, as_generator=False): + @properties.setter + def properties(self, value): + self._properties = value if callable(value) else lambda: value + + def __iter__(self): + return self.data() + + def __len__(self): + return self.length + + +class StreamlinesFile: + ''' Convenience class to encapsulate streamlines file format. ''' + + @classmethod + def get_magic_number(cls): + ''' Return streamlines file's magic number. ''' raise NotImplementedError() - def get_scalars(self, as_generator=False): + @classmethod + def is_correct_format(cls, fileobj): + ''' Check if the file has the right streamlines file format. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header) + + Returns + ------- + is_correct_format : boolean + Returns True if `fileobj` is in the right streamlines file format. + ''' raise NotImplementedError() - def get_properties(self, as_generator=False): + @classmethod + def get_empty_header(cls): + ''' Return an empty streamlines file's header. ''' raise NotImplementedError() @classmethod - def load(cls, fileobj): + def load(cls, fileobj, lazy_load=True): + ''' Loads streamlines from a file-like object. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header) + + lazy_load : boolean + Load streamlines in a lazy manner i.e. they will not be kept + in memory. For postprocessing speed, turn off this option. 
+ + Returns + ------- + streamlines : Streamlines object + Returns an object containing streamlines' data and header + information. See 'nibabel.Streamlines'. + ''' raise NotImplementedError() - def save(self, filename): + @classmethod + def save(cls, streamlines, fileobj): + ''' Saves streamlines to a file-like object. + + Parameters + ---------- + streamlines : Streamlines object + Object containing streamlines' data and header information. + See 'nibabel.Streamlines'. + + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + opened and ready to write. + ''' raise NotImplementedError() - def __iter__(self): + @staticmethod + def pretty_print(streamlines): + ''' Gets a formatted string contaning header's information + relevant to the streamlines file format. + + Parameters + ---------- + streamlines : Streamlines object + Object containing streamlines' data and header information. + See 'nibabel.Streamlines'. + + Returns + ------- + info : string + Header's information relevant to the streamlines file format. + ''' raise NotImplementedError() -class DynamicStreamlineFile(StreamlineFile): +class DynamicStreamlineFile(StreamlinesFile): + ''' Convenience class to encapsulate streamlines file format + that supports appending streamlines to an existing file. + ''' + def append(self, streamlines): raise NotImplementedError() def __iadd__(self, streamlines): - return self.append(streamlines) \ No newline at end of file + return self.append(streamlines) diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py index b12c42d8d4..aa40d3bb97 100644 --- a/nibabel/streamlines/header.py +++ b/nibabel/streamlines/header.py @@ -1,4 +1,11 @@ +from nibabel.orientations import aff2axcodes + + class Field: + """ Header fields common to multiple streamlines file formats. + + In IPython, use `nibabel.streamlines.Field??` to list them. + """ NB_STREAMLINES = "nb_streamlines" STEP_SIZE = "step_size" METHOD = "method" @@ -12,4 +19,39 @@ class Field: VOXEL_TO_WORLD = "voxel_to_world" VOXEL_ORDER = "voxel_order" WORLD_ORDER = "world_order" - ENDIAN = "endian" \ No newline at end of file + ENDIAN = "endian" + + +def create_header_from_nifti(img): + ''' Creates a common streamlines' header using a nifti image. + + Based on the information of the nifti image a dictionnary is created + containing the following keys: `Field.ORIGIN`, `Field.DIMENSIONS`, + `Field.VOXEL_SIZES`, `Field.VOXEL_TO_WORLD`, `Field.WORLD_ORDER` + and `Field.VOXEL_ORDER`. + + Parameters + ---------- + img : Nifti1Image object + Image containing information about the anatomy where streamlines + were created. + + Returns + ------- + hdr : dict + Header containing meta information about streamlines extracted + from the anatomy. 
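+
+    Examples
+    --------
+    A minimal sketch; the file name is hypothetical:
+
+    >>> import nibabel as nib
+    >>> img = nib.load('anat.nii.gz')        # doctest: +SKIP
+    >>> hdr = create_header_from_nifti(img)  # doctest: +SKIP
+    >>> hdr[Field.VOXEL_ORDER]               # doctest: +SKIP
+    'RAS'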
+ ''' + img_header = img.get_header() + affine = img_header.get_best_affine() + + hdr = {} + + hdr[Field.ORIGIN] = affine[:3, -1] + hdr[Field.DIMENSIONS] = img_header.get_data_shape()[:3] + hdr[Field.VOXEL_SIZES] = img_header.get_zooms()[:3] + hdr[Field.VOXEL_TO_WORLD] = affine + hdr[Field.WORLD_ORDER] = "RAS" # Nifti space + hdr[Field.VOXEL_ORDER] = "".join(aff2axcodes(affine)) + + return hdr diff --git a/nibabel/streamlines/tests/data/complex.trk b/nibabel/streamlines/tests/data/complex.trk new file mode 100644 index 0000000000000000000000000000000000000000..2a96a7bd35f6ce69f05907a351466ac838de1045 GIT binary patch literal 1228 zcmWFua&-1)U<5-3h6Z~CW`F}`IBTdgn1;FsklYW7D|A4K7`j=b{GlEKpnMBT@h^Z3 zM!09dX~Y4N&mn@(Dj03LHZ8>)ja@W21g)<0+6@>kgov52590AKz;xaC!mPia=8QD;O77UQAr@i literal 0 HcmV?d00001 diff --git a/nibabel/streamlines/tests/data/empty.trk b/nibabel/streamlines/tests/data/empty.trk new file mode 100644 index 0000000000000000000000000000000000000000..023b3c5905f222fbd76ce9f82ca72be319b427f6 GIT binary patch literal 1000 ycmWFua&-1)Sio=sh#43f>=78q9R-6p1VC|x4k!^rH*1tX972Ez=!qB13=9BLi3y|t literal 0 HcmV?d00001 diff --git a/nibabel/streamlines/tests/data/simple.trk b/nibabel/streamlines/tests/data/simple.trk new file mode 100644 index 0000000000000000000000000000000000000000..dc5eff4adc9cd919e7bfe31a37485ecc51348109 GIT binary patch literal 1108 zcmWFua&-1)U<5-3h6Z~CW*7y7Is`y*g$^hYLpN)bKh#5j8R!8fAbtU4Fv2|pP9qK= zaR`9$85kTKfO#K?7dWs&Wguk%15gYh$G~s^$bSID42}#80zj+)#0Eg@0K@@6oZtum D?XVOc literal 0 HcmV?d00001 diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py new file mode 100644 index 0000000000..25956bcfed --- /dev/null +++ b/nibabel/streamlines/tests/test_base_format.py @@ -0,0 +1,117 @@ +import os +import unittest +import numpy as np + +from nibabel.testing import assert_arrays_equal +from nose.tools import assert_equal, assert_raises + +from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines.base_format import HeaderError +from nibabel.streamlines.header import Field + +DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') + + +class TestStreamlines(unittest.TestCase): + + def setUp(self): + self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") + self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") + self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") + + self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11, 2.22], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + + def test_streamlines_creation_from_arrays(self): + # Empty + streamlines = Streamlines() + assert_equal(len(streamlines), 0) + + # TODO: Should Streamlines have a default header? It could have + # NB_STREAMLINES, NB_SCALARS_PER_POINT and NB_PROPERTIES_PER_STREAMLINE + # already set. + hdr = streamlines.get_header() + assert_equal(len(hdr), 0) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in streamlines: + pass + + # Only points + streamlines = Streamlines(points=self.points) + assert_equal(len(streamlines), len(self.points)) + assert_arrays_equal(streamlines.points, self.points) + + # Check if we can iterate through the streamlines. 
+ for point, scalar, prop in streamlines: + pass + + # Only scalars + streamlines = Streamlines(scalars=self.colors) + assert_equal(len(streamlines), 0) + assert_arrays_equal(streamlines.scalars, self.colors) + + # TODO: is it a faulty behavior? + assert_equal(len(list(streamlines)), len(self.colors)) + + # Points, scalars and properties + streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) + assert_equal(len(streamlines), len(self.points)) + assert_arrays_equal(streamlines.points, self.points) + assert_arrays_equal(streamlines.scalars, self.colors) + assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in streamlines: + pass + + def test_streamlines_creation_from_generators(self): + # Points, scalars and properties + points = (x for x in self.points) + scalars = (x for x in self.colors) + properties = (x for x in self.mean_curvature_torsion) + + assert_raises(HeaderError, Streamlines, points, scalars, properties) + + hdr = {Field.NB_STREAMLINES: len(self.points)} + streamlines = Streamlines(points, scalars, properties, hdr) + assert_equal(len(streamlines), len(self.points)) + assert_arrays_equal(streamlines.points, self.points) + assert_arrays_equal(streamlines.scalars, self.colors) + assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + + # Have been consumed + assert_equal(len(list(streamlines)), 0) + assert_equal(len(list(streamlines.points)), 0) + assert_equal(len(list(streamlines.scalars)), 0) + assert_equal(len(list(streamlines.properties)), 0) + + def test_streamlines_creation_from_functions(self): + # Points, scalars and properties + points = lambda: (x for x in self.points) + scalars = lambda: (x for x in self.colors) + properties = lambda: (x for x in self.mean_curvature_torsion) + + assert_raises(HeaderError, Streamlines, points, scalars, properties) + + hdr = {Field.NB_STREAMLINES: len(self.points)} + streamlines = Streamlines(points, scalars, properties, hdr) + assert_equal(len(streamlines), len(self.points)) + assert_arrays_equal(streamlines.points, self.points) + assert_arrays_equal(streamlines.scalars, self.colors) + assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + + # Have been consumed but lambda functions get re-called. 
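+        # (This is the point of storing callables on Streamlines: every
+        # property access rebuilds a fresh generator, so the data below can
+        # be traversed again, unlike the plain generators tested above.)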
+ assert_equal(len(list(streamlines)), len(self.points)) + assert_arrays_equal(streamlines.points, self.points) + assert_arrays_equal(streamlines.scalars, self.colors) + assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 0f01b03dfe..251be7a61d 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -1,56 +1,323 @@ -# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*- -# vi: set ft=python sts=4 ts=4 sw=4 et: +import os +import unittest +import numpy as np -from pdb import set_trace as dbg +from nibabel.externals.six import BytesIO -from os.path import join as pjoin, dirname - -from numpy.testing import (assert_equal, - assert_almost_equal, - assert_array_equal, - assert_array_almost_equal, - assert_raises) - -DATA_PATH = pjoin(dirname(__file__), 'data') +from nibabel.testing import assert_arrays_equal +from nose.tools import assert_equal, assert_raises import nibabel as nib +from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines.base_format import DataError, HeaderError from nibabel.streamlines.header import Field +from nibabel.streamlines.trk import TrkFile + +DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') + + +class TestTRK(unittest.TestCase): + + def setUp(self): + self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") + self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") + self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") + + self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11, 2.22], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + + def test_load_empty_file(self): + empty_trk = nib.streamlines.load(self.empty_trk_filename, lazy_load=False) + + hdr = empty_trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], 0) + assert_equal(len(empty_trk), 0) + + points = empty_trk.points + assert_equal(len(points), 0) + + scalars = empty_trk.scalars + assert_equal(len(scalars), 0) + + properties = empty_trk.properties + assert_equal(len(properties), 0) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in empty_trk: + pass + + def test_load_simple_file(self): + simple_trk = nib.streamlines.load(self.simple_trk_filename, lazy_load=False) + + hdr = simple_trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(simple_trk), len(self.points)) + + points = simple_trk.points + assert_arrays_equal(points, self.points) + + scalars = simple_trk.scalars + assert_equal(len(scalars), len(self.points)) + + properties = simple_trk.properties + assert_equal(len(properties), len(self.points)) + + # Check if we can iterate through the streamlines. 
+ for point, scalar, prop in simple_trk: + pass + + # Test lazy_load + simple_trk = nib.streamlines.load(self.simple_trk_filename, lazy_load=True) + + hdr = simple_trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(simple_trk), len(self.points)) + + points = simple_trk.points + assert_arrays_equal(points, self.points) + + scalars = simple_trk.scalars + assert_equal(len(list(scalars)), len(self.points)) + + properties = simple_trk.properties + assert_equal(len(list(properties)), len(self.points)) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in simple_trk: + pass + + def test_load_complex_file(self): + complex_trk = nib.streamlines.load(self.complex_trk_filename, lazy_load=False) + + hdr = complex_trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(complex_trk), len(self.points)) + + points = complex_trk.points + assert_arrays_equal(points, self.points) + + scalars = complex_trk.scalars + assert_arrays_equal(scalars, self.colors) + + properties = complex_trk.properties + assert_arrays_equal(properties, self.mean_curvature_torsion) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in complex_trk: + pass + + complex_trk = nib.streamlines.load(self.complex_trk_filename, lazy_load=True) + + hdr = complex_trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(complex_trk), len(self.points)) + + points = complex_trk.points + assert_arrays_equal(points, self.points) + + scalars = complex_trk.scalars + assert_arrays_equal(scalars, self.colors) + + properties = complex_trk.properties + assert_arrays_equal(properties, self.mean_curvature_torsion) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in complex_trk: + pass + + def test_write_simple_file(self): + streamlines = Streamlines(self.points) + + simple_trk_file = BytesIO() + TrkFile.save(streamlines, simple_trk_file) + + simple_trk_file.seek(0, os.SEEK_SET) + + simple_trk = nib.streamlines.load(simple_trk_file, lazy_load=False) + + hdr = simple_trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(simple_trk), len(self.points)) + + points = simple_trk.points + assert_arrays_equal(points, self.points) + + scalars = simple_trk.scalars + assert_equal(len(scalars), len(self.points)) + + properties = simple_trk.properties + assert_equal(len(properties), len(self.points)) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in simple_trk: + pass + + def test_write_complex_file(self): + # With scalars + streamlines = Streamlines(self.points, scalars=self.colors) + + trk_file = BytesIO() + TrkFile.save(streamlines, trk_file) + + trk_file.seek(0, os.SEEK_SET) + + trk = nib.streamlines.load(trk_file, lazy_load=False) + + hdr = trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(trk), len(self.points)) + + points = trk.points + assert_arrays_equal(points, self.points) + + scalars = trk.scalars + assert_arrays_equal(scalars, self.colors) + + properties = trk.properties + assert_equal(len(properties), len(self.points)) + + # Check if we can iterate through the streamlines. 
+ for point, scalar, prop in trk: + pass + + # With properties + streamlines = Streamlines(self.points, properties=self.mean_curvature_torsion) + + trk_file = BytesIO() + TrkFile.save(streamlines, trk_file) + + trk_file.seek(0, os.SEEK_SET) + + trk = nib.streamlines.load(trk_file, lazy_load=False) + + hdr = trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(trk), len(self.points)) + + points = trk.points + assert_arrays_equal(points, self.points) + + scalars = trk.scalars + assert_equal(len(scalars), len(self.points)) + + properties = trk.properties + assert_arrays_equal(properties, self.mean_curvature_torsion) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in trk: + pass + + # With scalars and properties + streamlines = Streamlines(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) + + trk_file = BytesIO() + TrkFile.save(streamlines, trk_file) + + trk_file.seek(0, os.SEEK_SET) + + trk = nib.streamlines.load(trk_file, lazy_load=False) + + hdr = trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(trk), len(self.points)) + + points = trk.points + assert_arrays_equal(points, self.points) + + scalars = trk.scalars + assert_arrays_equal(scalars, self.colors) + + properties = trk.properties + assert_arrays_equal(properties, self.mean_curvature_torsion) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in trk: + pass + + def test_write_erroneous_file(self): + # No scalars for every points + scalars = [[(1, 0, 0)], + [(0, 1, 0)], + [(0, 0, 1)]] + + streamlines = Streamlines(self.points, scalars) + assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + + # No scalars for every streamlines + scalars = [[(1, 0, 0)]*1, + [(0, 1, 0)]*2] + + streamlines = Streamlines(self.points, scalars) + assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + + # Inconsistent number of scalars between points + scalars = [[(1, 0, 0)]*1, + [(0, 1, 0), (0, 1)], + [(0, 0, 1)]*5] + + streamlines = Streamlines(self.points, scalars) + assert_raises(ValueError, TrkFile.save, streamlines, BytesIO()) + + # Inconsistent number of scalars between streamlines + scalars = [[(1, 0, 0)]*1, + [(0, 1)]*2, + [(0, 0, 1)]*5] + + streamlines = Streamlines(self.points, scalars) + assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + + # Inconsistent number of properties + properties = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + streamlines = Streamlines(self.points, properties=properties) + assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + + # No properties for every streamlines + properties = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11, 2.22], dtype="f4")] + streamlines = Streamlines(self.points, properties=properties) + assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + + def test_write_file_from_generator(self): + gen_points = (point for point in self.points) + gen_scalars = (scalar for scalar in self.colors) + gen_properties = (prop for prop in self.mean_curvature_torsion) + + assert_raises(HeaderError, Streamlines, points=gen_points, scalars=gen_scalars, properties=gen_properties) + + hdr = {Field.NB_STREAMLINES: len(self.points)} + streamlines = Streamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties, hdr=hdr) + + trk_file = BytesIO() + TrkFile.save(streamlines, trk_file) + + trk_file.seek(0, os.SEEK_SET) + 
+ trk = nib.streamlines.load(trk_file, lazy_load=False) + + hdr = trk.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(trk), len(self.points)) + + points = trk.points + assert_arrays_equal(points, self.points) + + scalars = trk.scalars + assert_arrays_equal(scalars, self.colors) + properties = trk.properties + assert_arrays_equal(properties, self.mean_curvature_torsion) -def test_load_file(): - # Test loading empty file - # empty_file = pjoin(DATA_PATH, "empty.trk") - # empty_trk = nib.streamlines.load(empty_file) - - # hdr = empty_trk.get_header() - # points = empty_trk.get_points(as_generator=False) - # scalars = empty_trk.get_scalars(as_generator=False) - # properties = empty_trk.get_properties(as_generator=False) - - # assert_equal(hdr[Field.NB_STREAMLINES], 0) - # assert_equal(len(points), 0) - # assert_equal(len(scalars), 0) - # assert_equal(len(properties), 0) - - # for i in empty_trk: pass # Check if we can iterate through the streamlines. - - # Test loading non-empty file - trk_file = pjoin(DATA_PATH, "uncinate.trk") - trk = nib.streamlines.load(trk_file) - - hdr = trk.get_header() - points = trk.get_points(as_generator=False) - 1/0 - scalars = trk.get_scalars(as_generator=False) - properties = trk.get_properties(as_generator=False) - - assert_equal(hdr[Field.NB_STREAMLINES] > 0, True) - assert_equal(len(points) > 0, True) - #assert_equal(len(scalars), 0) - #assert_equal(len(properties), 0) - - for i in trk: pass # Check if we can iterate through the streamlines. - - -if __name__ == "__main__": - test_load_file() \ No newline at end of file + # Check if we can iterate through the streamlines. + for point, scalar, prop in trk: + pass diff --git a/nibabel/streamlines/tests/test_utils.py b/nibabel/streamlines/tests/test_utils.py new file mode 100644 index 0000000000..a0dbcd8b61 --- /dev/null +++ b/nibabel/streamlines/tests/test_utils.py @@ -0,0 +1,323 @@ +import os +import unittest +import tempfile +import numpy as np + +from os.path import join as pjoin + +from nibabel.externals.six import BytesIO + +from nibabel.testing import assert_arrays_equal +from nose.tools import assert_equal, assert_raises, assert_true, assert_false + +import nibabel.streamlines.utils as streamline_utils + +from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines.base_format import HeaderError +from nibabel.streamlines.header import Field + +DATA_PATH = pjoin(os.path.dirname(__file__), 'data') + + +def test_is_supported(): + # Emtpy file/string + f = BytesIO() + assert_false(streamline_utils.is_supported(f)) + assert_false(streamline_utils.is_supported("")) + + # Valid file without extension + for streamlines_file in streamline_utils.FORMATS.values(): + f = BytesIO() + f.write(streamlines_file.get_magic_number()) + f.seek(0, os.SEEK_SET) + assert_true(streamline_utils.is_supported(f)) + + # Wrong extension but right magic number + for streamlines_file in streamline_utils.FORMATS.values(): + with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f: + f.write(streamlines_file.get_magic_number()) + f.seek(0, os.SEEK_SET) + assert_true(streamline_utils.is_supported(f)) + + # Good extension but wrong magic number + for ext, streamlines_file in streamline_utils.FORMATS.items(): + with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f: + f.write(b"pass") + f.seek(0, os.SEEK_SET) + assert_false(streamline_utils.is_supported(f)) + + # Wrong extension, string only + f = "my_streamlines.asd" + assert_false(streamline_utils.is_supported(f)) + + # 
Good extension, string only + for ext, streamlines_file in streamline_utils.FORMATS.items(): + f = "my_streamlines" + ext + assert_true(streamline_utils.is_supported(f)) + + +def test_detect_format(): + # Emtpy file/string + f = BytesIO() + assert_equal(streamline_utils.detect_format(f), None) + assert_equal(streamline_utils.detect_format(""), None) + + # Valid file without extension + for streamlines_file in streamline_utils.FORMATS.values(): + f = BytesIO() + f.write(streamlines_file.get_magic_number()) + f.seek(0, os.SEEK_SET) + assert_equal(streamline_utils.detect_format(f), streamlines_file) + + # Wrong extension but right magic number + for streamlines_file in streamline_utils.FORMATS.values(): + with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f: + f.write(streamlines_file.get_magic_number()) + f.seek(0, os.SEEK_SET) + assert_equal(streamline_utils.detect_format(f), streamlines_file) + + # Good extension but wrong magic number + for ext, streamlines_file in streamline_utils.FORMATS.items(): + with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f: + f.write(b"pass") + f.seek(0, os.SEEK_SET) + assert_equal(streamline_utils.detect_format(f), None) + + # Wrong extension, string only + f = "my_streamlines.asd" + assert_equal(streamline_utils.detect_format(f), None) + + # Good extension, string only + for ext, streamlines_file in streamline_utils.FORMATS.items(): + f = "my_streamlines" + ext + assert_equal(streamline_utils.detect_format(f), streamlines_file) + + +class TestLoadSave(unittest.TestCase): + # Testing scalars and properties depend on the format. + # See unit tests in the specific format test file. + + def setUp(self): + self.empty_filenames = [pjoin(DATA_PATH, "empty" + ext) for ext in streamline_utils.FORMATS.keys()] + self.simple_filenames = [pjoin(DATA_PATH, "simple" + ext) for ext in streamline_utils.FORMATS.keys()] + self.complex_filenames = [pjoin(DATA_PATH, "complex" + ext) for ext in streamline_utils.FORMATS.keys()] + + self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11, 2.22], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + + def test_load_empty_file(self): + for empty_filename in self.empty_filenames: + empty_streamlines = streamline_utils.load(empty_filename, lazy_load=False) + + hdr = empty_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], 0) + assert_equal(len(empty_streamlines), 0) + + points = empty_streamlines.points + assert_equal(len(points), 0) + + # For an empty file, scalars should be zero regardless of the format. + scalars = empty_streamlines.scalars + assert_equal(len(scalars), 0) + + # For an empty file, properties should be zero regardless of the format. + properties = empty_streamlines.properties + assert_equal(len(properties), 0) + + # Check if we can iterate through the streamlines. 
+ for point, scalar, prop in empty_streamlines: + pass + + def test_load_simple_file(self): + for simple_filename in self.simple_filenames: + simple_streamlines = streamline_utils.load(simple_filename, lazy_load=False) + + hdr = simple_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(simple_streamlines), len(self.points)) + + points = simple_streamlines.points + assert_arrays_equal(points, self.points) + + # Testing scalars and properties depend on the format. + # See unit tests in the specific format test file. + + # Check if we can iterate through the streamlines. + for point, scalar, prop in simple_streamlines: + pass + + # Test lazy_load + simple_streamlines = streamline_utils.load(simple_filename, lazy_load=True) + + hdr = simple_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(simple_streamlines), len(self.points)) + + points = simple_streamlines.points + assert_arrays_equal(points, self.points) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in simple_streamlines: + pass + + def test_load_complex_file(self): + for complex_filename in self.complex_filenames: + complex_streamlines = streamline_utils.load(complex_filename, lazy_load=False) + + hdr = complex_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(complex_streamlines), len(self.points)) + + points = complex_streamlines.points + assert_arrays_equal(points, self.points) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in complex_streamlines: + pass + + complex_streamlines = streamline_utils.load(complex_filename, lazy_load=True) + + hdr = complex_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(complex_streamlines), len(self.points)) + + points = complex_streamlines.points + assert_arrays_equal(points, self.points) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in complex_streamlines: + pass + + def test_save_simple_file(self): + for ext in streamline_utils.FORMATS.keys(): + with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: + streamlines = Streamlines(self.points) + + streamline_utils.save(streamlines, f.name) + simple_streamlines = streamline_utils.load(f, lazy_load=False) + + hdr = simple_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(simple_streamlines), len(self.points)) + + points = simple_streamlines.points + assert_arrays_equal(points, self.points) + + def test_save_complex_file(self): + for ext in streamline_utils.FORMATS.keys(): + # With scalars + with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: + streamlines = Streamlines(self.points, scalars=self.colors) + + streamline_utils.save(streamlines, f.name) + complex_streamlines = streamline_utils.load(f, lazy_load=False) + + hdr = complex_streamlines.get_header() + assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) + assert_equal(len(complex_streamlines), len(self.points)) + + points = complex_streamlines.points + assert_arrays_equal(points, self.points) + + # Check if we can iterate through the streamlines. 
+                for point, scalar, prop in complex_streamlines:
+                    pass
+
+            # With properties
+            with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f:
+                streamlines = Streamlines(self.points, properties=self.mean_curvature_torsion)
+
+                streamline_utils.save(streamlines, f.name)
+                complex_streamlines = streamline_utils.load(f, lazy_load=False)
+
+                hdr = complex_streamlines.get_header()
+                assert_equal(hdr[Field.NB_STREAMLINES], len(self.points))
+                assert_equal(len(complex_streamlines), len(self.points))
+
+                points = complex_streamlines.points
+                assert_arrays_equal(points, self.points)
+
+                # Check if we can iterate through the streamlines.
+                for point, scalar, prop in complex_streamlines:
+                    pass
+
+            # With scalars and properties
+            with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f:
+                streamlines = Streamlines(self.points, scalars=self.colors, properties=self.mean_curvature_torsion)
+
+                streamline_utils.save(streamlines, f.name)
+                complex_streamlines = streamline_utils.load(f, lazy_load=False)
+
+                hdr = complex_streamlines.get_header()
+                assert_equal(hdr[Field.NB_STREAMLINES], len(self.points))
+                assert_equal(len(complex_streamlines), len(self.points))
+
+                points = complex_streamlines.points
+                assert_arrays_equal(points, self.points)
+
+                # Check if we can iterate through the streamlines.
+                for point, scalar, prop in complex_streamlines:
+                    pass
+
+    def test_save_file_from_generator(self):
+        for ext in streamline_utils.FORMATS.keys():
+            # With scalars and properties
+            with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f:
+                gen_points = (point for point in self.points)
+                gen_scalars = (scalar for scalar in self.colors)
+                gen_properties = (prop for prop in self.mean_curvature_torsion)
+
+                assert_raises(HeaderError, Streamlines, points=gen_points, scalars=gen_scalars, properties=gen_properties)
+
+                hdr = {Field.NB_STREAMLINES: len(self.points)}
+                streamlines = Streamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties, hdr=hdr)
+
+                streamline_utils.save(streamlines, f.name)
+                streamlines_loaded = streamline_utils.load(f, lazy_load=False)
+
+                hdr = streamlines_loaded.get_header()
+                assert_equal(hdr[Field.NB_STREAMLINES], len(self.points))
+                assert_equal(len(streamlines_loaded), len(self.points))
+
+                points = streamlines_loaded.points
+                assert_arrays_equal(points, self.points)
+
+                # Check if we can iterate through the streamlines.
+                for point, scalar, prop in streamlines_loaded:
+                    pass
+
+    def test_save_file_from_function(self):
+        for ext in streamline_utils.FORMATS.keys():
+            # With scalars and properties
+            with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f:
+                gen_points = lambda: (point for point in self.points)
+                gen_scalars = lambda: (scalar for scalar in self.colors)
+                gen_properties = lambda: (prop for prop in self.mean_curvature_torsion)
+
+                assert_raises(HeaderError, Streamlines, points=gen_points, scalars=gen_scalars, properties=gen_properties)
+
+                hdr = {Field.NB_STREAMLINES: len(self.points)}
+                streamlines = Streamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties, hdr=hdr)
+
+                streamline_utils.save(streamlines, f.name)
+                streamlines_loaded = streamline_utils.load(f, lazy_load=False)
+
+                hdr = streamlines_loaded.get_header()
+                assert_equal(hdr[Field.NB_STREAMLINES], len(self.points))
+                assert_equal(len(streamlines_loaded), len(self.points))
+
+                points = streamlines_loaded.points
+                assert_arrays_equal(points, self.points)
+
+                # Check if we can iterate through the streamlines.
+                for point, scalar, prop in streamlines_loaded:
+                    pass
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index 6a02397cf6..c256bcc0ec 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -1,107 +1,92 @@
+from __future__ import division
+
 # Documentation available here:
 # http://www.trackvis.org/docs/?subsect=fileformat
 
-from pdb import set_trace as dbg
+from ..externals.six.moves import xrange
+import struct
 
 import os
 import warnings
+
 import numpy as np
-from numpy.lib.recfunctions import append_fields
 
 from nibabel.openers import Opener
-from nibabel.volumeutils import (native_code, swapped_code, endian_codes)
+from nibabel.volumeutils import (native_code, swapped_code)
 
-from nibabel.streamlines.base_format import DynamicStreamlineFile
+from nibabel.streamlines.base_format import Streamlines, StreamlinesFile
 from nibabel.streamlines.header import Field
+from nibabel.streamlines.base_format import DataError, HeaderError
 
 # Definition of trackvis header structure.
 # See http://www.trackvis.org/docs/?subsect=fileformat
 # See http://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
-header_1_dtd = [
-    (Field.MAGIC_NUMBER, 'S6'),
-    (Field.DIMENSIONS, 'h', 3),
-    (Field.VOXEL_SIZES, 'f4', 3),
-    (Field.ORIGIN, 'f4', 3),
-    (Field.NB_SCALARS_PER_POINT, 'h'),
-    ('scalar_name', 'S20', 10),
-    (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'),
-    ('property_name', 'S20', 10),
-    ('reserved', 'S508'),
-    (Field.VOXEL_ORDER, 'S4'),
-    ('pad2', 'S4'),
-    ('image_orientation_patient', 'f4', 6),
-    ('pad1', 'S2'),
-    ('invert_x', 'S1'),
-    ('invert_y', 'S1'),
-    ('invert_z', 'S1'),
-    ('swap_xy', 'S1'),
-    ('swap_yz', 'S1'),
-    ('swap_zx', 'S1'),
-    (Field.NB_STREAMLINES, 'i4'),
-    ('version', 'i4'),
-    ('hdr_size', 'i4'),
-    ]
+header_1_dtd = [(Field.MAGIC_NUMBER, 'S6'),
+                (Field.DIMENSIONS, 'h', 3),
+                (Field.VOXEL_SIZES, 'f4', 3),
+                (Field.ORIGIN, 'f4', 3),
+                (Field.NB_SCALARS_PER_POINT, 'h'),
+                ('scalar_name', 'S20', 10),
+                (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'),
+                ('property_name', 'S20', 10),
+                ('reserved', 'S508'),
+                (Field.VOXEL_ORDER, 'S4'),
+                ('pad2', 'S4'),
+                ('image_orientation_patient', 'f4', 6),
+                ('pad1', 'S2'),
+                ('invert_x', 'S1'),
+                ('invert_y', 'S1'),
+                ('invert_z', 'S1'),
+                ('swap_xy', 'S1'),
+                ('swap_yz', 'S1'),
+                ('swap_zx', 'S1'),
+                (Field.NB_STREAMLINES, 'i4'),
+                ('version', 'i4'),
+                ('hdr_size', 'i4'),
+                ]
 
 # Version 2 adds a 4x4 matrix giving the affine transformation going
 # from voxel coordinates in the referenced 3D voxel matrix, to xyz
 # coordinates (axes L->R, P->A, I->S). IF (0 based) value [3, 3] from
 # this matrix is 0, this means the matrix is not recorded.
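(Editor's illustration, not part of this patch: given a raw header string
`hdr_str` read from a TRK file, the check described in the comment above can
be written as follows; `hdr_str` is assumed to hold the 1000-byte header.)

    hdr = np.fromstring(hdr_str, dtype=header_2_dtype)[0]
    affine = hdr[Field.VOXEL_TO_WORLD]
    if affine[3, 3] == 0:    # (0 based) value [3, 3] is 0: matrix not recorded
        affine = np.eye(4)   # fall back to the identity transform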
-header_2_dtd = [ - (Field.MAGIC_NUMBER, 'S6'), - (Field.DIMENSIONS, 'h', 3), - (Field.VOXEL_SIZES, 'f4', 3), - (Field.ORIGIN, 'f4', 3), - (Field.NB_SCALARS_PER_POINT, 'h'), - ('scalar_name', 'S20', 10), - (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), - ('property_name', 'S20', 10), - (Field.VOXEL_TO_WORLD, 'f4', (4,4)), # new field for version 2 - ('reserved', 'S444'), - (Field.VOXEL_ORDER, 'S4'), - ('pad2', 'S4'), - ('image_orientation_patient', 'f4', 6), - ('pad1', 'S2'), - ('invert_x', 'S1'), - ('invert_y', 'S1'), - ('invert_z', 'S1'), - ('swap_xy', 'S1'), - ('swap_yz', 'S1'), - ('swap_zx', 'S1'), - (Field.NB_STREAMLINES, 'i4'), - ('version', 'i4'), - ('hdr_size', 'i4'), - ] +header_2_dtd = [(Field.MAGIC_NUMBER, 'S6'), + (Field.DIMENSIONS, 'h', 3), + (Field.VOXEL_SIZES, 'f4', 3), + (Field.ORIGIN, 'f4', 3), + (Field.NB_SCALARS_PER_POINT, 'h'), + ('scalar_name', 'S20', 10), + (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), + ('property_name', 'S20', 10), + (Field.VOXEL_TO_WORLD, 'f4', (4, 4)), # new field for version 2 + ('reserved', 'S444'), + (Field.VOXEL_ORDER, 'S4'), + ('pad2', 'S4'), + ('image_orientation_patient', 'f4', 6), + ('pad1', 'S2'), + ('invert_x', 'S1'), + ('invert_y', 'S1'), + ('invert_z', 'S1'), + ('swap_xy', 'S1'), + ('swap_yz', 'S1'), + ('swap_zx', 'S1'), + (Field.NB_STREAMLINES, 'i4'), + ('version', 'i4'), + ('hdr_size', 'i4'), + ] # Full header numpy dtypes header_1_dtype = np.dtype(header_1_dtd) header_2_dtype = np.dtype(header_2_dtd) -class HeaderError(Exception): - pass - - -class DataError(Exception): - pass - - -class TrkFile(DynamicStreamlineFile): - MAGIC_NUMBER = "TRACK" - OFFSET = 1000 +class TrkFile(StreamlinesFile): + ''' Convenience class to encapsulate TRK format. ''' - def __init__(self, hdr, streamlines, scalars, properties): - self.filename = None - - self.hdr = hdr - self.streamlines = streamlines - self.scalars = scalars - self.properties = properties + MAGIC_NUMBER = b"TRACK" + HEADER_SIZE = 1000 - ##### - # Static Methods - ### @classmethod def get_magic_number(cls): - ''' Return TRK's magic number ''' + ''' Return TRK's magic number. ''' return cls.MAGIC_NUMBER @classmethod @@ -111,354 +96,355 @@ def is_correct_format(cls, fileobj): Parameters ---------- fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to TRK file (and ready to read from the beginning - of the TRK header data) + If string, a filename; otherwise an open file-like object + pointing to TRK file (and ready to read from the beginning + of the TRK header data) Returns ------- is_correct_format : boolean - Returns True if `fileobj` is in TRK format, False otherwise. + Returns True if `fileobj` is in TRK format. ''' - with Opener(fileobj) as fileobj: - magic_number = fileobj.read(5) - fileobj.seek(-5, os.SEEK_CUR) + with Opener(fileobj) as f: + magic_number = f.read(5) + f.seek(-5, os.SEEK_CUR) return magic_number == cls.MAGIC_NUMBER return False @classmethod - def load(cls, fileobj): - hdr = {} - pos_header = 0 - pos_data = 0 + def sanity_check(cls, fileobj): + ''' Check if data is consistent with information contained in the header. 
+        [Might be useful]
+
+        Parameters
+        ----------
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            pointing to TRK file (and ready to read from the beginning
+            of the TRK header data)
+
+        Returns
+        -------
+        is_consistent : boolean
+            Returns True if data is consistent with header, False otherwise.
+        '''
+        is_consistent = True
+
+        with Opener(fileobj) as f:
+            start_position = f.tell()
 
-        with Opener(fileobj) as fileobj:
-            pos_header = fileobj.tell()
-
-            #####
             # Read header
-            ###
-            hdr_str = fileobj.read(header_2_dtype.itemsize)
-            hdr = np.fromstring(string=hdr_str, dtype=header_2_dtype)
-
-            if hdr['version'] == 1:
-                hdr = np.fromstring(string=hdr_str, dtype=header_1_dtype)
-            elif hdr['version'] == 2:
-                pass  # Nothing more to do here
+            hdr_str = f.read(header_2_dtype.itemsize)
+            hdr_rec = np.fromstring(string=hdr_str, dtype=header_2_dtype)
+
+            if hdr_rec['version'] == 1:
+                hdr_rec = np.fromstring(string=hdr_str, dtype=header_1_dtype)
+            elif hdr_rec['version'] == 2:
+                pass  # Nothing more to do here
             else:
-                raise HeaderError('NiBabel only supports versions 1 and 2.')
+                warnings.warn("NiBabel only supports versions 1 and 2 (not v.{0}).".format(hdr_rec['version']))
+                f.seek(start_position, os.SEEK_SET)  # Restore the file position.
+                return False
 
-            # Make header a dictionnary instead of ndarray
-            hdr = dict(zip(hdr.dtype.names, hdr[0]))
+            # Convert the first record of `hdr_rec` into a dictionary
+            hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0]))
 
             # Check endianness
-            #hdr = append_fields(hdr, Field.ENDIAN, [native_code], usemask=False)
             hdr[Field.ENDIAN] = native_code
-            if hdr['hdr_size'] != 1000:
+            if hdr['hdr_size'] != cls.HEADER_SIZE:
                 hdr[Field.ENDIAN] = swapped_code
-                hdr = hdr.newbyteorder()
-                if hdr['hdr_size'] != 1000:
-                    raise HeaderError('Invalid hdr_size of {0}'.format(hdr['hdr_size']))
-
-            # Add more header fields implied by trk format.
-            #hdr = append_fields(hdr, Field.WORLD_ORDER, ["RAS"], usemask=False)
-            hdr[Field.WORLD_ORDER] = "RAS"
-
-            pos_data = fileobj.tell()
+                hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0].newbyteorder()))  # Swap byte order
+                if hdr['hdr_size'] != cls.HEADER_SIZE:
+                    warnings.warn("Invalid hdr_size: {0} instead of {1}".format(hdr['hdr_size'], cls.HEADER_SIZE))
+                    f.seek(start_position, os.SEEK_SET)  # Restore the file position.
+                    return False
+
+            # By default, the voxel order is LPS.
+            # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates
+            if hdr[Field.VOXEL_ORDER] == "":
+                is_consistent = False
+                warnings.warn("Voxel order is not specified, will assume 'LPS' since it is Trackvis software's default.")
 
             # Add more header fields implied by trk format.
             i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4")
             f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4")
 
-            nb_streamlines = 0
+            pts_and_scalars_size = (3 + hdr[Field.NB_SCALARS_PER_POINT]) * f4_dtype.itemsize
+            properties_size = hdr[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize
 
-            #Either verify the number of streamlines specified in the header is correct or
-            # count the actual number of streamlines in case it was not specified in the header.
+            # Verify the number of streamlines specified in the header is correct.
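+            # (Each streamline record is laid out as: one int32 point count,
+            #  then nb_pts * (3 + nb_scalars_per_point) float32 values, then
+            #  nb_properties_per_streamline float32 values, so a record can be
+            #  skipped with a single relative seek.)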
+            nb_streamlines = 0
             while True:
                 # Read number of points of the streamline
-                buf = fileobj.read(i4_dtype.itemsize)
+                buf = f.read(i4_dtype.itemsize)
 
-                if buf == '':
-                    break  # EOF
+                if buf == b'':  # Compare against bytes: `f` is opened in binary mode.
+                    break  # EOF
 
-                nb_pts = np.fromstring(buf,
-                                       dtype=i4_dtype,
-                                       count=1)
+                nb_pts = struct.unpack(i4_dtype.str[:-1], buf)[0]
 
-                bytes_to_skip = nb_pts * 3  # x, y, z coordinates
-                bytes_to_skip += nb_pts * hdr[Field.NB_SCALARS_PER_POINT]
-                bytes_to_skip += hdr[Field.NB_PROPERTIES_PER_STREAMLINE]
+                bytes_to_skip = nb_pts * pts_and_scalars_size
+                bytes_to_skip += properties_size
 
                 # Seek to the next streamline in the file.
-                fileobj.seek(bytes_to_skip * f4_dtype.itemsize, os.SEEK_CUR)
+                f.seek(bytes_to_skip, os.SEEK_CUR)  # `bytes_to_skip` is already in bytes.
 
                 nb_streamlines += 1
 
             if hdr[Field.NB_STREAMLINES] != nb_streamlines:
-                warnings.warn('The number of streamlines specified in header ({1}) does not match ' +
-                              'the actual number of streamlines contained in this file ({1}). ' +
-                              'The latter will be used.'.format(hdr[Field.NB_STREAMLINES], nb_streamlines))
+                is_consistent = False
+                warnings.warn(('The number of streamlines specified in header ({0}) does not match '
                               'the actual number of streamlines contained in this file ({1}). '
+                               ).format(hdr[Field.NB_STREAMLINES], nb_streamlines))
 
-            hdr[Field.NB_STREAMLINES] = nb_streamlines
+            f.seek(start_position, os.SEEK_SET)  # Restore the file position.
+
+        return is_consistent
+
+    @classmethod
+    def load(cls, fileobj, lazy_load=True):
+        ''' Loads streamlines from a file-like object.
+
+        Parameters
+        ----------
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            pointing to TRK file (and ready to read from the beginning
+            of the TRK header)
 
-        trk_file = cls(hdr, [], [], [])
-        trk_file.pos_header = pos_header
-        trk_file.pos_data = pos_data
-        trk_file.streamlines
+        lazy_load : boolean
+            Load streamlines in a lazy manner, i.e. they will not all be kept
+            in memory. For faster postprocessing, turn this option off.
 
-        return trk_file
-        # cls(hdr, streamlines, scalars, properties)
+        Returns
+        -------
+        streamlines : Streamlines object
+            Returns an object containing streamlines' data and header
+            information. See 'nibabel.Streamlines'.
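+
+        Examples
+        --------
+        Illustrative only (the filename below is hypothetical)::
+
+            trk = TrkFile.load("tracks.trk", lazy_load=False)
+            for pts, scalars, props in trk:
+                pass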
+        '''
+        hdr = {}
 
-    def get_header(self):
-        return self.hdr
+        with Opener(fileobj) as f:
+            hdr['pos_header'] = f.tell()
 
-    def get_points(self, as_generator=False):
-        self.fileobj.seek(self.pos_data, os.SEEK_SET)
-        pos = self.pos_data
+            # Read header
+            hdr_str = f.read(header_2_dtype.itemsize)
+            hdr_rec = np.fromstring(string=hdr_str, dtype=header_2_dtype)
 
-        i4_dtype = np.dtype(self.hdr[Field.ENDIAN] + "i4")
-        f4_dtype = np.dtype(self.hdr[Field.ENDIAN] + "f4")
+            if hdr_rec['version'] == 1:
+                hdr_rec = np.fromstring(string=hdr_str, dtype=header_1_dtype)
+            elif hdr_rec['version'] == 2:
+                pass  # Nothing more to do here
+            else:
+                raise HeaderError('NiBabel only supports versions 1 and 2.')
 
-        for i in range(self.hdr[Field.NB_STREAMLINES]):
-            # Read number of points of the streamline
-            nb_pts = np.fromstring(self.fileobj.read(i4_dtype.itemsize),
-                                   dtype=i4_dtype,
-                                   count=1)
-
-            # Read points of the streamline
-            pts = np.fromstring(self.fileobj.read(nb_pts * 3 * i4_dtype.itemsize),
-                                dtype=[f4_dtype, f4_dtype, f4_dtype],
-                                count=nb_pts)
+            # Convert the first record of `hdr_rec` into a dictionary
+            hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0]))
 
-            pos = self.fileobj.tell()
-            yield pts
-            self.fileobj.seek(pos, os.SEEK_SET)
-
-            bytes_to_skip = nb_pts * self.hdr[Field.NB_SCALARS_PER_POINT]
-            bytes_to_skip += self.hdr[Field.NB_PROPERTIES_PER_STREAMLINE]
-
-            # Seek to the next streamline in the file.
-            self.fileobj.seek(bytes_to_skip * f4_dtype.itemsize, os.SEEK_CUR)
-
-    #####
-    # Methods
-    ###
-
-
-
-
-# import os
-# import logging
-# import numpy as np

-# from tractconverter.formats.header import Header as H
-
-
-# def readBinaryBytes(f, nbBytes, dtype):
-#     buff = f.read(nbBytes * dtype.itemsize)
-#     return np.frombuffer(buff, dtype=dtype)
-
-
-# class TRK:
-#     # self.hdr
-#     # self.filename
-#     # self.hdr[H.ENDIAN]
-#     # self.FIBER_DELIMITER
-#     # self.END_DELIMITER
-
-#     @staticmethod
-#     def create(filename, hdr, anatFile=None):
-#         f = open(filename, 'wb')
-#         f.write(TRK.MAGIC_NUMBER + "\n")
-#         f.close()
-
-#         trk = TRK(filename, load=False)
-#         trk.hdr = hdr
-#         trk.writeHeader()
-
-#         return trk
-
-#     #####
-#     # Methods
-#     ###
-#     def __init__(self, filename, anatFile=None, load=True):
-#         if not TRK._check(filename):
-#             raise NameError("Not a TRK file.")
-
-#         self.filename = filename
-#         self.hdr = {}
-#         if load:
-#             self._load()
-
-#     def _load(self):
-#         f = open(self.filename, 'rb')
-
-#         #####
-#         # Read header
-#         ###
-#         self.hdr[H.MAGIC_NUMBER] = f.read(6)
-#         self.hdr[H.DIMENSIONS] = np.frombuffer(f.read(6), dtype='i4')
-#         self.hdr["version"] = self.hdr["version"].astype('>i4')
-#         self.hdr["hdr_size"] = self.hdr["hdr_size"].astype('>i4')
-
-#         nb_fibers = 0
-#         self.hdr[H.NB_POINTS] = 0
-
-#         #Either verify the number of streamlines specified in the header is correct or
-#         # count the actual number of streamlines in case it is not specified in the header.
-#         remainingBytes = os.path.getsize(self.filename) - self.OFFSET
-#         while remainingBytes > 0:
-#             # Read points
-#             nbPoints = readBinaryBytes(f, 1, np.dtype(self.hdr[H.ENDIAN] + "i4"))[0]
-#             self.hdr[H.NB_POINTS] += nbPoints
-#             # This seek is used to go to the next points number indication in the file.
-# f.seek((nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]) -# + self.hdr[H.NB_PROPERTIES_PER_STREAMLINE]) * 4, 1) # Relative seek -# remainingBytes -= (nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]) -# + self.hdr[H.NB_PROPERTIES_PER_STREAMLINE]) * 4 + 4 -# nb_fibers += 1 - -# if self.hdr[H.NB_FIBERS] != nb_fibers: -# logging.warn('The number of streamlines specified in header ({1}) does not match ' + -# 'the actual number of streamlines contained in this file ({1}). ' + -# 'The latter will be used.'.format(self.hdr[H.NB_FIBERS], nb_fibers)) - -# self.hdr[H.NB_FIBERS] = nb_fibers - -# f.close() - -# def writeHeader(self): -# # Get the voxel size and format it as an array. -# voxel_sizes = np.asarray(self.hdr.get(H.VOXEL_SIZES, (1.0, 1.0, 1.0)), dtype=' 0: -# # Read points -# nbPoints = readBinaryBytes(f, 1, np.dtype(self.hdr[H.ENDIAN] + "i4"))[0] -# ptsAndScalars = readBinaryBytes(f, -# nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]), -# np.dtype(self.hdr[H.ENDIAN] + "f4")) - -# newShape = [-1, 3 + self.hdr[H.NB_SCALARS_PER_POINT]] -# ptsAndScalars = ptsAndScalars.reshape(newShape) - -# pointsWithoutScalars = ptsAndScalars[:, 0:3] -# yield pointsWithoutScalars - -# # For now, we do not process the tract properties, so just skip over them. -# remainingBytes -= nbPoints * (3 + self.hdr[H.NB_SCALARS_PER_POINT]) * 4 + 4 -# remainingBytes -= self.hdr[H.NB_PROPERTIES_PER_STREAMLINE] * 4 -# cpt += 1 - -# f.close() - -# def __str__(self): -# text = "" -# text += "MAGIC NUMBER: {0}".format(self.hdr[H.MAGIC_NUMBER]) -# text += "v.{0}".format(self.hdr['version']) -# text += "dim: {0}".format(self.hdr[H.DIMENSIONS]) -# text += "voxel_sizes: {0}".format(self.hdr[H.VOXEL_SIZES]) -# text += "orgin: {0}".format(self.hdr[H.ORIGIN]) -# text += "nb_scalars: {0}".format(self.hdr[H.NB_SCALARS_PER_POINT]) -# text += "scalar_name:\n {0}".format("\n".join(self.hdr['scalar_name'])) -# text += "nb_properties: {0}".format(self.hdr[H.NB_PROPERTIES_PER_STREAMLINE]) -# text += "property_name:\n {0}".format("\n".join(self.hdr['property_name'])) -# text += "vox_to_world: {0}".format(self.hdr[H.VOXEL_TO_WORLD]) -# text += "world_order: {0}".format(self.hdr[H.WORLD_ORDER]) -# text += "voxel_order: {0}".format(self.hdr[H.VOXEL_ORDER]) -# text += "image_orientation_patient: {0}".format(self.hdr['image_orientation_patient']) -# text += "pad1: {0}".format(self.hdr['pad1']) -# text += "pad2: {0}".format(self.hdr['pad2']) -# text += "invert_x: {0}".format(self.hdr['invert_x']) -# text += "invert_y: {0}".format(self.hdr['invert_y']) -# text += "invert_z: {0}".format(self.hdr['invert_z']) -# text += "swap_xy: {0}".format(self.hdr['swap_xy']) -# text += "swap_yz: {0}".format(self.hdr['swap_yz']) -# text += "swap_zx: {0}".format(self.hdr['swap_zx']) -# text += "n_count: {0}".format(self.hdr[H.NB_FIBERS]) -# text += "hdr_size: {0}".format(self.hdr['hdr_size']) -# text += "endianess: {0}".format(self.hdr[H.ENDIAN]) - -# return text + # Check endianness + hdr[Field.ENDIAN] = native_code + if hdr['hdr_size'] != cls.HEADER_SIZE: + hdr[Field.ENDIAN] = swapped_code + hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0].newbyteorder())) # Swap byte order + if hdr['hdr_size'] != cls.HEADER_SIZE: + raise HeaderError('Invalid hdr_size: {0} instead of {1}'.format(hdr['hdr_size'], cls.HEADER_SIZE)) + + # By default, the voxel order is LPS. 
+            # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates
+            if hdr[Field.VOXEL_ORDER] == "":
+                hdr[Field.VOXEL_ORDER] = "LPS"
+
+            # Add more header fields implied by trk format.
+            hdr[Field.WORLD_ORDER] = "RAS"
+            hdr['pos_data'] = f.tell()
+
+        points = lambda: (x[0] for x in TrkFile._read_data(hdr, fileobj))
+        scalars = lambda: (x[1] for x in TrkFile._read_data(hdr, fileobj))
+        properties = lambda: (x[2] for x in TrkFile._read_data(hdr, fileobj))
+        data = lambda: TrkFile._read_data(hdr, fileobj)
+
+        if lazy_load:
+            streamlines = Streamlines(points, scalars, properties, hdr=hdr)
+            streamlines.data = data
+            return streamlines
+
+        return Streamlines(*zip(*data()), hdr=hdr)
+
+    @staticmethod
+    def _read_data(hdr, fileobj):
+        ''' Read streamlines' data from a file-like object using a TRK's header. '''
+        i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4")
+        f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4")
+
+        with Opener(fileobj) as f:
+            nb_pts_and_scalars = 3 + int(hdr[Field.NB_SCALARS_PER_POINT])
+            pts_and_scalars_size = nb_pts_and_scalars * f4_dtype.itemsize
+
+            slice_pts_and_scalars = lambda data: (data, [])
+            if hdr[Field.NB_SCALARS_PER_POINT] > 0:
+                # This is faster than np.split
+                slice_pts_and_scalars = lambda data: (data[:, :3], data[:, 3:])
+
+            # Using np.fromfile would be faster, but does not support StringIO
+            read_pts_and_scalars = lambda nb_pts: slice_pts_and_scalars(np.ndarray(shape=(nb_pts, nb_pts_and_scalars),
+                                                                                   dtype=f4_dtype,
+                                                                                   buffer=f.read(nb_pts * pts_and_scalars_size)))
+
+            properties_size = int(hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) * f4_dtype.itemsize
+            read_properties = lambda: []
+            if hdr[Field.NB_PROPERTIES_PER_STREAMLINE] > 0:
+                read_properties = lambda: np.fromstring(f.read(properties_size),
+                                                        dtype=f4_dtype,
+                                                        count=hdr[Field.NB_PROPERTIES_PER_STREAMLINE])
+
+            f.seek(hdr['pos_data'], os.SEEK_SET)
+
+            for i in xrange(hdr[Field.NB_STREAMLINES]):
+                # Read number of points of the next streamline.
+                nb_pts = struct.unpack(i4_dtype.str[:-1], f.read(i4_dtype.itemsize))[0]
+
+                # Read streamline's data
+                pts, scalars = read_pts_and_scalars(nb_pts)
+                properties = read_properties()
+                yield pts, scalars, properties
+
+    @classmethod
+    def get_empty_header(cls):
+        ''' Return an empty TRK header. '''
+        hdr = np.zeros(1, dtype=header_2_dtype)
+
+        # Default values
+        hdr[Field.MAGIC_NUMBER] = cls.MAGIC_NUMBER
+        hdr[Field.VOXEL_SIZES] = (1, 1, 1)
+        hdr[Field.DIMENSIONS] = (1, 1, 1)
+        hdr[Field.VOXEL_TO_WORLD] = np.eye(4)
+        hdr['version'] = 2
+        hdr['hdr_size'] = cls.HEADER_SIZE
+
+        return hdr
+
+    @classmethod
+    def save(cls, streamlines, fileobj):
+        ''' Saves streamlines to a file-like object.
+
+        Parameters
+        ----------
+        streamlines : Streamlines object
+            Object containing streamlines' data and header information.
+            See 'nibabel.Streamlines'.
+
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            pointing to TRK file (and ready to write from the beginning
+            of the TRK header data)
+        '''
+        hdr = cls.get_empty_header()
+
+        # Override hdr's fields with those contained in `streamlines`'s header
+        for k, v in streamlines.get_header().items():
+            if k in header_2_dtype.fields.keys():
+                hdr[k] = v
+
+        # Check which endianness to use to write data.
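+        # (`native_code` and `swapped_code` are single-character numpy
+        #  byte-order codes, so prefixing them to "i4"/"f4" below yields
+        #  dtypes with an explicit endianness.)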
+        endianness = streamlines.get_header().get(Field.ENDIAN, native_code)
+
+        if endianness == swapped_code:
+            hdr = hdr.newbyteorder()
+
+        i4_dtype = np.dtype(endianness + "i4")
+        f4_dtype = np.dtype(endianness + "f4")
+
+        # Keep counts for correcting incoherent fields or warning.
+        nb_streamlines = 0
+        nb_points = 0
+        nb_scalars = 0
+        nb_properties = 0
+
+        # Write header + data of streamlines
+        with Opener(fileobj, mode="wb") as f:
+            pos = f.tell()
+            f.write(hdr[0].tostring())
+
+            for points, scalars, properties in streamlines:
+                if len(scalars) > 0 and len(scalars) != len(points):
+                    raise DataError("Missing scalars for some points!")
+
+                points = np.array(points, dtype=f4_dtype)
+                scalars = np.array(scalars, dtype=f4_dtype).reshape((len(points), -1))
+                properties = np.array(properties, dtype=f4_dtype)
+
+                data = struct.pack(i4_dtype.str[:-1], len(points))
+                data += np.concatenate((points, scalars), axis=1).tostring()
+                data += properties.tostring()
+                f.write(data)
+
+                nb_streamlines += 1
+                nb_points += len(points)
+                nb_scalars += scalars.size
+                nb_properties += len(properties)
+
+            # Either correct or warn if header and data are incoherent.
+            # TODO: add a warn option as a function parameter
+            # Guard against empty streamlines to avoid dividing by zero.
+            nb_scalars_per_point = nb_scalars / nb_points if nb_points > 0 else 0
+            nb_properties_per_streamline = nb_properties / nb_streamlines if nb_streamlines > 0 else 0
+
+            # Check for errors
+            if nb_scalars_per_point != int(nb_scalars_per_point):
+                raise DataError("Nb. of scalars differs from one point to another!")
+
+            if nb_properties_per_streamline != int(nb_properties_per_streamline):
+                raise DataError("Nb. of properties differs from one streamline to another!")
+
+            hdr[Field.NB_STREAMLINES] = nb_streamlines
+            hdr[Field.NB_SCALARS_PER_POINT] = nb_scalars_per_point
+            hdr[Field.NB_PROPERTIES_PER_STREAMLINE] = nb_properties_per_streamline
+
+            f.seek(pos, os.SEEK_SET)
+            f.write(hdr[0].tostring())  # Overwrite header with updated one.
+
+    @staticmethod
+    def pretty_print(streamlines):
+        ''' Gets a formatted string containing header information
+            relevant to the TRK format.
+
+        Parameters
+        ----------
+        streamlines : Streamlines object
+            Object containing streamlines' data and header information.
+            See 'nibabel.Streamlines'.
+
+        Returns
+        -------
+        info : string
+            Header information relevant to the TRK format.
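+
+        Examples
+        --------
+        Illustrative only (the filename below is hypothetical)::
+
+            trk = TrkFile.load("tracks.trk", lazy_load=False)
+            print(TrkFile.pretty_print(trk))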
+        '''
+        hdr = streamlines.get_header()
+
+        info = ""
+        info += "MAGIC NUMBER: {0}\n".format(hdr[Field.MAGIC_NUMBER])
+        info += "v.{0}\n".format(hdr['version'])
+        info += "dim: {0}\n".format(hdr[Field.DIMENSIONS])
+        info += "voxel_sizes: {0}\n".format(hdr[Field.VOXEL_SIZES])
+        info += "origin: {0}\n".format(hdr[Field.ORIGIN])
+        info += "nb_scalars: {0}\n".format(hdr[Field.NB_SCALARS_PER_POINT])
+        info += "scalar_name:\n {0}\n".format("\n".join(hdr['scalar_name']))
+        info += "nb_properties: {0}\n".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE])
+        info += "property_name:\n {0}\n".format("\n".join(hdr['property_name']))
+        info += "vox_to_world: {0}\n".format(hdr[Field.VOXEL_TO_WORLD])
+        info += "world_order: {0}\n".format(hdr[Field.WORLD_ORDER])
+        info += "voxel_order: {0}\n".format(hdr[Field.VOXEL_ORDER])
+        info += "image_orientation_patient: {0}\n".format(hdr['image_orientation_patient'])
+        info += "pad1: {0}\n".format(hdr['pad1'])
+        info += "pad2: {0}\n".format(hdr['pad2'])
+        info += "invert_x: {0}\n".format(hdr['invert_x'])
+        info += "invert_y: {0}\n".format(hdr['invert_y'])
+        info += "invert_z: {0}\n".format(hdr['invert_z'])
+        info += "swap_xy: {0}\n".format(hdr['swap_xy'])
+        info += "swap_yz: {0}\n".format(hdr['swap_yz'])
+        info += "swap_zx: {0}\n".format(hdr['swap_zx'])
+        info += "n_count: {0}\n".format(hdr[Field.NB_STREAMLINES])
+        info += "hdr_size: {0}\n".format(hdr['hdr_size'])
+        info += "endianness: {0}\n".format(hdr[Field.ENDIAN])
+
+        return info
diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py
index 511f27fe01..bb5515988d 100644
--- a/nibabel/streamlines/utils.py
+++ b/nibabel/streamlines/utils.py
@@ -1,47 +1,135 @@
 import os
-import nibabel as nib
+from ..externals.six import string_types
 
 from nibabel.streamlines.trk import TrkFile
 #from nibabel.streamlines.tck import TckFile
 #from nibabel.streamlines.vtk import VtkFile
-#from nibabel.streamlines.fib import FibFile
 
-# Supported format
-FORMATS = {"trk": TrkFile,
-           #"tck": TckFile,
-           #"vtk": VtkFile,
-           #"fib": FibFile,
+# List of all supported formats
+FORMATS = {".trk": TrkFile,
+           #".tck": TckFile,
+           #".vtk": VtkFile,
           }
 
+
 def is_supported(fileobj):
+    ''' Checks if the file-like object is supported by NiBabel.
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object pointing
+        to a streamlines file (and ready to read from the beginning of the
+        header)
+
+    Returns
+    -------
+    is_supported : boolean
+    '''
    return detect_format(fileobj) is not None
 
 
 def detect_format(fileobj):
+    ''' Returns the StreamlinesFile object guessed from the file-like object.
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object pointing
+        to a streamlines file (and ready to read from the beginning of the
+        header)
+
+    Returns
+    -------
+    streamlines_file : StreamlinesFile object
+        Object that can be used to manage a streamlines file.
+        See 'nibabel.streamlines.StreamlinesFile'.
+    '''
     for format in FORMATS.values():
-        if format.is_correct_format(fileobj):
-            return format
+        try:
+            if format.is_correct_format(fileobj):
+                return format
+
+        except IOError:
+            pass
 
-    if isinstance(fileobj, basestring):
+    if isinstance(fileobj, string_types):
         _, ext = os.path.splitext(fileobj)
         return FORMATS.get(ext, None)
 
     return None
 
 
+def load(fileobj, lazy_load=True, anat=None):
+    ''' Loads streamlines from a file-like object in their native space.
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object
+        pointing to a streamlines file (and ready to read from the beginning
+        of the streamlines file's header)
+
+    lazy_load : boolean, optional
+        If True, streamlines are loaded lazily, i.e. not all kept in memory.
+
+    anat : optional
+        Reference anatomy; currently unused (see TODO below).
+
+    Returns
+    -------
+    obj : instance of ``Streamlines``
+        Returns an instance of a ``Streamlines`` class containing data and metadata
+        of streamlines loaded from ``fileobj``.
+    '''
+
+    # TODO: Ask everyone what should be the behavior if the anat is provided.
+    # if anat is None:
+    #     warnings.warn("WARNING: Streamlines will be loaded in their native space (i.e. as they were saved).")
+
+    streamlines_file = detect_format(fileobj)
+
+    if streamlines_file is None:
+        raise TypeError("Unknown format for 'fileobj': {0}!".format(fileobj))
+
+    return streamlines_file.load(fileobj, lazy_load=lazy_load)
+
+
+def save(streamlines, filename):
+    ''' Saves a ``Streamlines`` object to a file.
+
+    Parameters
+    ----------
+    streamlines : Streamlines object
+        Streamlines to be saved (metadata is obtained with the function ``get_header`` of ``streamlines``).
+
+    filename : string
+        Name of the file where the streamlines will be saved. The format will be guessed from ``filename``.
+    '''
+
+    streamlines_file = detect_format(filename)
+
+    if streamlines_file is None:
+        raise TypeError("Unknown format for 'filename': {0}!".format(filename))
+
+    streamlines_file.save(streamlines, filename)
+
+
 def convert(in_fileobj, out_filename):
-    in_fileobj = nib.streamlines.load(in_fileobj)
-    out_format = nib.streamlines.guess_format(out_filename)
+    ''' Converts one streamlines format to another.
+
+    It does not change the space in which the streamlines are expressed.
 
-    hdr = in_fileobj.get_header()
-    points = in_fileobj.get_points(as_generator=True)
-    scalars = in_fileobj.get_scalars(as_generator=True)
-    properties = in_fileobj.get_properties(as_generator=True)
+    Parameters
+    ----------
+    in_fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object pointing
+        to a streamlines file (and ready to read from the beginning of the
+        header)
 
-    out_fileobj = out_format(hdr, points, scalars, properties)
-    out_fileobj.save(out_filename)
+    out_filename : string
+        Name of the file where the streamlines will be saved. The format will
+        be guessed from ``out_filename``.
+    '''
+    streamlines = load(in_fileobj, lazy_load=True)
+    save(streamlines, out_filename)
 
 
+# TODO
 def change_space(streamline_file, new_point_space):
-    pass
\ No newline at end of file
+    pass
diff --git a/setup.py b/setup.py
index d5160b4a0c..ff2a161980 100755
--- a/setup.py
+++ b/setup.py
@@ -90,6 +90,7 @@ def main(**extra_args):
                     'nibabel.testing',
                     'nibabel.tests',
                     'nibabel.benchmarks',
+                    'nibabel.streamlines',
                     # install nisext as its own package
                     'nisext',
                     'nisext.tests'],
@@ -104,6 +105,7 @@ def main(**extra_args):
                          pjoin('externals', 'tests', 'data', '*'),
                          pjoin('nicom', 'tests', 'data', '*'),
                          pjoin('gifti', 'tests', 'data', '*'),
+                         pjoin('streamlines', 'tests', 'data', '*'),
                         ]},
   scripts      = [pjoin('bin', 'parrec2nii'),
                   pjoin('bin', 'nib-ls'),

From 0a6599bb78a3b4b10957352030be060a55ec6cdf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Wed, 4 Mar 2015 10:19:22 -0500
Subject: [PATCH 003/135] Added a LazyStreamlines class and made sure
 streamlines are in voxel space.
---
 nibabel/streamlines/__init__.py               |   2 +-
 nibabel/streamlines/base_format.py            | 158 ++++++++++++--
 nibabel/streamlines/header.py                 |  51 ++++-
 nibabel/streamlines/tests/test_base_format.py | 205 ++++++++++++++++--
 nibabel/streamlines/tests/test_trk.py         |  48 ++--
 nibabel/streamlines/tests/test_utils.py       |  62 +++---
 nibabel/streamlines/trk.py                    | 135 +++++++++---
 nibabel/streamlines/utils.py                  |  37 +++-
 8 files changed, 555 insertions(+), 143 deletions(-)

diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py
index cd7cd56a21..7f00864c5f 100644
--- a/nibabel/streamlines/__init__.py
+++ b/nibabel/streamlines/__init__.py
@@ -1,7 +1,7 @@
 from nibabel.streamlines.utils import load, save
 
-from nibabel.streamlines.base_format import Streamlines
+from nibabel.streamlines.base_format import Streamlines, LazyStreamlines
 from nibabel.streamlines.header import Field
 
 from nibabel.streamlines.trk import TrkFile
diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py
index 43e3b8616e..30baf4001e 100644
--- a/nibabel/streamlines/base_format.py
+++ b/nibabel/streamlines/base_format.py
@@ -1,3 +1,5 @@
+import numpy as np
+from warnings import warn
 
 from nibabel.streamlines.header import Field
 
@@ -21,45 +23,126 @@ class Streamlines(object):
 
     Parameters
     ----------
-    points : sequence of ndarray of shape (N, 3)
+    points : list of ndarray of shape (N, 3)
         Sequence of T streamlines. One streamline is an ndarray of shape (N, 3)
         where N is the number of points in a streamline.
 
-    scalars : sequence of ndarray of shape (N, M)
+    scalars : list of ndarray of shape (N, M)
         Sequence of T ndarrays of shape (N, M) where T is the number of
         streamlines defined by ``points``, N is the number of points
        for a particular streamline and M is the number of scalars
         associated to each point (excluding the three coordinates).
 
-    properties : sequence of ndarray of shape (P,)
+    properties : list of ndarray of shape (P,)
         Sequence of T ndarrays of shape (P,) where T is the number of
         streamlines defined by ``points``, P is the number of properties
-        associated to each streamlines.
+        associated to each streamline.
 
     hdr : dict
         Header containing meta information about the streamlines. For a list
         of common header's fields to use as keys see `nibabel.streamlines.Field`.
     '''
-    def __init__(self, points=[], scalars=[], properties=[], hdr={}):
-        self.hdr = hdr
+    def __init__(self, points=[], scalars=[], properties=[]):  #, hdr={}):
+        # Create basic header from given information.
+        self._header = {}
+        self._header[Field.VOXEL_TO_WORLD] = np.eye(4)
 
         self.points = points
         self.scalars = scalars
         self.properties = properties
-        self.data = lambda: zip_longest(self.points, self.scalars, self.properties, fillvalue=[])
 
-        try:
-            self.length = len(points)
-        except:
-            if Field.NB_STREAMLINES in hdr:
-                self.length = hdr[Field.NB_STREAMLINES]
-            else:
-                raise HeaderError(("Neither parameter 'points' nor 'hdr' contain information about"
-                                   " number of streamlines. Use key '{0}' to set the number of "
-                                   "streamlines in 'hdr'.").format(Field.NB_STREAMLINES))
 
-    def get_header(self):
-        return self.hdr
+    @property
+    def header(self):
+        return self._header
+
+    @property
+    def points(self):
+        return self._points
+
+    @points.setter
+    def points(self, value):
+        self._points = value
+        self._header[Field.NB_STREAMLINES] = len(self.points)
+
+    @property
+    def scalars(self):
+        return self._scalars
+
+    @scalars.setter
+    def scalars(self, value):
+        self._scalars = value
+        self._header[Field.NB_SCALARS_PER_POINT] = 0
+        if len(self.scalars) > 0:
+            # Number of scalars per point, i.e. the length of one row of an
+            # (N, M) array, not the number of rows.
+            self._header[Field.NB_SCALARS_PER_POINT] = len(self.scalars[0][0])
+
+    @property
+    def properties(self):
+        return self._properties
+
+    @properties.setter
+    def properties(self, value):
+        self._properties = value
+        self._header[Field.NB_PROPERTIES_PER_STREAMLINE] = 0
+        if len(self.properties) > 0:
+            self._header[Field.NB_PROPERTIES_PER_STREAMLINE] = len(self.properties[0])
+
+    def __iter__(self):
+        return zip_longest(self.points, self.scalars, self.properties, fillvalue=[])
+
+    def __getitem__(self, idx):
+        pts = self.points[idx]
+        scalars = []
+        if len(self.scalars) > 0:
+            scalars = self.scalars[idx]
+
+        properties = []
+        if len(self.properties) > 0:
+            properties = self.properties[idx]
+
+        return pts, scalars, properties
+
+    def __len__(self):
+        return len(self.points)
+
+
+class LazyStreamlines(Streamlines):
+    ''' Class containing information about streamlines.
+
+    Streamlines objects have three main properties: ``points``, ``scalars``
+    and ``properties``. Streamlines objects can be iterated over, producing
+    a tuple of ``points``, ``scalars`` and ``properties`` for each streamline.
+
+    Parameters
+    ----------
+    points : sequence of ndarray of shape (N, 3)
+        Sequence of T streamlines. One streamline is an ndarray of shape (N, 3)
+        where N is the number of points in a streamline.
+
+    scalars : sequence of ndarray of shape (N, M)
+        Sequence of T ndarrays of shape (N, M) where T is the number of
+        streamlines defined by ``points``, N is the number of points
+        for a particular streamline and M is the number of scalars
+        associated to each point (excluding the three coordinates).
+
+    properties : sequence of ndarray of shape (P,)
+        Sequence of T ndarrays of shape (P,) where T is the number of
+        streamlines defined by ``points``, P is the number of properties
+        associated to each streamline.
+
+    hdr : dict
+        Header containing meta information about the streamlines. For a list
+        of common header's fields to use as keys see `nibabel.streamlines.Field`.
+    '''
+    def __init__(self, points=[], scalars=[], properties=[], data=None, count=None, getitem=None):  #, hdr={}):
+        super(LazyStreamlines, self).__init__(points, scalars, properties)
+
+        self._data = lambda: zip_longest(self.points, self.scalars, self.properties, fillvalue=[])
+        if data is not None:
+            self._data = data if callable(data) else lambda: data
+
+        self._count = count
+        self._getitem = getitem
 
     @property
     def points(self):
@@ -67,7 +150,7 @@ def points(self):
 
     @points.setter
     def points(self, value):
-        self._points = value if callable(value) else (lambda: value)
+        self._points = value if callable(value) else lambda: value
 
     @property
     def scalars(self):
@@ -85,11 +168,44 @@ def properties(self):
     def properties(self, value):
         self._properties = value if callable(value) else lambda: value
 
+    def __getitem__(self, idx):
+        if self._getitem is None:
+            raise AttributeError('`LazyStreamlines` does not support indexing.')
+
+        return self._getitem(idx)
+
     def __iter__(self):
-        return self.data()
+        return self._data()
 
     def __len__(self):
-        return self.length
+        # If length is unknown, we'll try to get it as rapidly and accurately as possible.
+        if self._count is None:
+            # Length might be contained in the header.
+            if Field.NB_STREAMLINES in self.header:
+                return self.header[Field.NB_STREAMLINES]
+
+        if callable(self._count):
+            # Length might be obtained by re-parsing the file (if streamlines come from one).
+            self._count = self._count()
+
+        if self._count is None:
+            try:
+                # Will work if `points` is a finite sequence (e.g. list, ndarray)
+                self._count = len(self.points)
+            except:
+                pass
+
+        if self._count is None:
+            # As a last resort, count the streamlines by iterating through them
+            # (this consumes one-shot generators).
+            warn("Number of streamlines will be determined manually by looping"
+                 " through the streamlines. Note this will consume any"
+                 " generator used to create this `Streamlines` object. If you"
+                 " know the actual number of streamlines, you might want to"
+                 " set `Field.NB_STREAMLINES` of `self.header` beforehand.")
 
+            return sum(1 for _ in self)
+
+        return self._count
 
 
 class StreamlinesFile:
diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py
index aa40d3bb97..357930c012 100644
--- a/nibabel/streamlines/header.py
+++ b/nibabel/streamlines/header.py
@@ -1,4 +1,7 @@
+import numpy as np
+import nibabel
 from nibabel.orientations import aff2axcodes
+from nibabel.spatialimages import SpatialImage
 
 
 class Field:
@@ -18,12 +21,23 @@ class Field:
     ORIGIN = "origin"
     VOXEL_TO_WORLD = "voxel_to_world"
     VOXEL_ORDER = "voxel_order"
-    WORLD_ORDER = "world_order"
+    #WORLD_ORDER = "world_order"
     ENDIAN = "endian"
 
 
+def create_header_from_reference(ref):
+    if type(ref) is np.ndarray:
+        return create_header_from_affine(ref)
+    elif isinstance(ref, SpatialImage):
+        return create_header_from_nifti(ref)
+
+    # Assume `ref` is a filename:
+    img = nibabel.load(ref)
+    return create_header_from_nifti(img)
+
+
 def create_header_from_nifti(img):
-    ''' Creates a common streamlines' header using a nifti image.
+    ''' Creates a common streamlines' header using a spatial image.
 
     Based on the information of the nifti image, a dictionary is created
     containing the following keys: `Field.ORIGIN`, `Field.DIMENSIONS`,
     `Field.VOXEL_SIZES`, `Field.VOXEL_TO_WORLD` and `Field.VOXEL_ORDER`.
 
     Parameters
     ----------
-    img : Nifti1Image object
+    img : `SpatialImage` object
         Image containing information about the anatomy where streamlines
         were created.
@@ -51,7 +65,36 @@
     hdr[Field.DIMENSIONS] = img_header.get_data_shape()[:3]
     hdr[Field.VOXEL_SIZES] = img_header.get_zooms()[:3]
     hdr[Field.VOXEL_TO_WORLD] = affine
-    hdr[Field.WORLD_ORDER] = "RAS"  # Nifti space
+    #hdr[Field.WORLD_ORDER] = "RAS"  # Nifti space
     hdr[Field.VOXEL_ORDER] = "".join(aff2axcodes(affine))
 
     return hdr
 
 
+def create_header_from_affine(affine):
+    ''' Creates a common streamlines' header using an affine matrix.
+
+    Based on the information of the affine matrix, a dictionary is created
+    containing the following keys: `Field.ORIGIN`, `Field.VOXEL_SIZES`,
+    `Field.VOXEL_TO_WORLD` and `Field.VOXEL_ORDER`.
+
+    Parameters
+    ----------
+    affine : 2D array (4,4)
+        Affine matrix that transforms streamlines from voxel space to world space.
+
+    Returns
+    -------
+    hdr : dict
+        Header containing meta information about streamlines.
+    '''
+    hdr = {}
+
+    hdr[Field.ORIGIN] = affine[:3, -1]
+    hdr[Field.VOXEL_SIZES] = np.sqrt(np.sum(affine[:3, :3]**2, axis=0))
+    hdr[Field.VOXEL_TO_WORLD] = affine
+    #hdr[Field.WORLD_ORDER] = "RAS"  # Nifti space
+    hdr[Field.VOXEL_ORDER] = "".join(aff2axcodes(affine))
+
+    return hdr
diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py
index 25956bcfed..624b2d23d5 100644
--- a/nibabel/streamlines/tests/test_base_format.py
+++ b/nibabel/streamlines/tests/test_base_format.py
@@ -4,15 +4,16 @@
 from nibabel.testing import assert_arrays_equal
 from nose.tools import assert_equal, assert_raises
+from numpy.testing import assert_array_equal
 
-from nibabel.streamlines.base_format import Streamlines
+from nibabel.streamlines.base_format import Streamlines, LazyStreamlines
 from nibabel.streamlines.base_format import HeaderError
 from nibabel.streamlines.header import Field
 
 DATA_PATH = os.path.join(os.path.dirname(__file__), 'data')
 
 
-class TestStreamlines(unittest.TestCase):
+class TestLazyStreamlines(unittest.TestCase):
 
     def setUp(self):
         self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk")
@@ -33,21 +34,21 @@ def setUp(self):
 
     def test_streamlines_creation_from_arrays(self):
         # Empty
-        streamlines = Streamlines()
-        assert_equal(len(streamlines), 0)
+        streamlines = LazyStreamlines()
 
-        # TODO: Should Streamlines have a default header? It could have
-        # NB_STREAMLINES, NB_SCALARS_PER_POINT and NB_PROPERTIES_PER_STREAMLINE
-        # already set.
-        hdr = streamlines.get_header()
-        assert_equal(len(hdr), 0)
+        # A LazyStreamlines created from arrays has a minimal default header:
+        # only VOXEL_TO_WORLD is set (the count-related fields are computed
+        # on demand).
+        hdr = streamlines.header
+        assert_equal(len(hdr), 1)
+        assert_equal(len(streamlines), 0)
 
         # Check if we can iterate through the streamlines.
         for point, scalar, prop in streamlines:
             pass
 
         # Only points
-        streamlines = Streamlines(points=self.points)
+        streamlines = LazyStreamlines(points=self.points)
         assert_equal(len(streamlines), len(self.points))
         assert_arrays_equal(streamlines.points, self.points)
 
         # Check if we can iterate through the streamlines.
         for point, scalar, prop in streamlines:
             pass
 
         # Only scalars
-        streamlines = Streamlines(scalars=self.colors)
-        assert_equal(len(streamlines), 0)
+        streamlines = LazyStreamlines(scalars=self.colors)
         assert_arrays_equal(streamlines.scalars, self.colors)
 
         # TODO: is it a faulty behavior?
assert_equal(len(list(streamlines)), len(self.colors)) # Points, scalars and properties - streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) + streamlines = LazyStreamlines(self.points, self.colors, self.mean_curvature_torsion) assert_equal(len(streamlines), len(self.points)) assert_arrays_equal(streamlines.points, self.points) assert_arrays_equal(streamlines.scalars, self.colors) @@ -80,11 +80,11 @@ def test_streamlines_creation_from_generators(self): scalars = (x for x in self.colors) properties = (x for x in self.mean_curvature_torsion) - assert_raises(HeaderError, Streamlines, points, scalars, properties) + streamlines = LazyStreamlines(points, scalars, properties) + + # LazyStreamlines object does not support indexing. + assert_raises(AttributeError, streamlines.__getitem__, 0) - hdr = {Field.NB_STREAMLINES: len(self.points)} - streamlines = Streamlines(points, scalars, properties, hdr) - assert_equal(len(streamlines), len(self.points)) assert_arrays_equal(streamlines.points, self.points) assert_arrays_equal(streamlines.scalars, self.colors) assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) @@ -95,17 +95,31 @@ def test_streamlines_creation_from_generators(self): assert_equal(len(list(streamlines.scalars)), 0) assert_equal(len(list(streamlines.properties)), 0) + # Test function len + points = (x for x in self.points) + streamlines = LazyStreamlines(points, scalars, properties) + + # This will consume generator `points`. + # Note this will produce a warning message. + assert_equal(len(streamlines), len(self.points)) + assert_equal(len(streamlines), 0) + + # It will use `Field.NB_STREAMLINES` if it is in the streamlines header + # Note this won't produce a warning message. + streamlines.header[Field.NB_STREAMLINES] = len(self.points) + assert_equal(len(streamlines), len(self.points)) + def test_streamlines_creation_from_functions(self): # Points, scalars and properties points = lambda: (x for x in self.points) scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - assert_raises(HeaderError, Streamlines, points, scalars, properties) + streamlines = LazyStreamlines(points, scalars, properties) + + # LazyStreamlines object does not support indexing. + assert_raises(AttributeError, streamlines.__getitem__, 0) - hdr = {Field.NB_STREAMLINES: len(self.points)} - streamlines = Streamlines(points, scalars, properties, hdr) - assert_equal(len(streamlines), len(self.points)) assert_arrays_equal(streamlines.points, self.points) assert_arrays_equal(streamlines.scalars, self.colors) assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) @@ -115,3 +129,152 @@ def test_streamlines_creation_from_functions(self): assert_arrays_equal(streamlines.points, self.points) assert_arrays_equal(streamlines.scalars, self.colors) assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + + # Test function `len` + # Calling `len` will create a new generator each time. + # Note this will produce a warning message. + assert_equal(len(streamlines), len(self.points)) + assert_equal(len(streamlines), len(self.points)) + + # It will use `Field.NB_STREAMLINES` if it is in the streamlines header + # Note this won't produce a warning message. 
+ streamlines.header[Field.NB_STREAMLINES] = len(self.points) + assert_equal(len(streamlines), len(self.points)) + + def test_len(self): + # Points, scalars and properties + points = lambda: (x for x in self.points) + + # Function `len` is computed differently depending on available information. + # When `points` is a list, `len` will use `len(points)`. + streamlines = LazyStreamlines(points=self.points) + assert_equal(len(streamlines), len(self.points)) + + # When `points` is a generator, `len` will iterate through the streamlines + # and consume the generator. + # TODO: check that it has raised a warning message. + streamlines = LazyStreamlines(points=points()) + assert_equal(len(streamlines), len(self.points)) + assert_equal(len(streamlines), 0) + + # When `points` is a callable object that creates a generator, `len` will iterate + # through the streamlines. + # TODO: check that it has raised a warning message. + streamlines = LazyStreamlines(points=points) + assert_equal(len(streamlines), len(self.points)) + assert_equal(len(streamlines), len(self.points)) + + + # No matter what `points` is, if `Field.NB_STREAMLINES` is set in the header + # `len` returns that value. If not and `count` argument is specified, `len` + # will use that information to return a value. + # TODO: check that no warning messages are raised. + for pts in [self.points, points(), points]: + # `Field.NB_STREAMLINES` is set in the header. + streamlines = LazyStreamlines(points=pts) + streamlines.header[Field.NB_STREAMLINES] = 42 + assert_equal(len(streamlines), 42) + + # `count` is an integer. + streamlines = LazyStreamlines(points=pts, count=42) + assert_equal(len(streamlines), 42) + + # `count` argument is a callable object. + nb_calls = [0] + def count(): + nb_calls[0] += 1 + return 42 + + streamlines = LazyStreamlines(points=points, count=count) + assert_equal(len(streamlines), 42) + assert_equal(len(streamlines), 42) + # Check that the callable object is only called once (caching). + assert_equal(nb_calls[0], 1) + + +class TestStreamlines(unittest.TestCase): + + def setUp(self): + self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") + self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") + self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") + + self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11, 2.22], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + + def test_streamlines_creation_from_arrays(self): + # Empty + streamlines = Streamlines() + assert_equal(len(streamlines), 0) + + # Streamlines have a default header: + # NB_STREAMLINES, NB_SCALARS_PER_POINT, NB_PROPERTIES_PER_STREAMLINE + # and VOXEL_TO_WORLD. + hdr = streamlines.header + assert_equal(len(hdr), 4) + + # Check if we can iterate through the streamlines. + for points, scalars, props in streamlines: + pass + + # Only points + streamlines = Streamlines(points=self.points) + assert_equal(len(streamlines), len(self.points)) + assert_arrays_equal(streamlines.points, self.points) + + # Check if we can iterate through the streamlines. 
+ for points, scalars, props in streamlines: + pass + + # Only scalars + streamlines = Streamlines(scalars=self.colors) + assert_equal(len(streamlines), 0) + assert_arrays_equal(streamlines.scalars, self.colors) + + # TODO: is it a faulty behavior? + assert_equal(len(list(streamlines)), len(self.colors)) + + # Points, scalars and properties + streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) + assert_equal(len(streamlines), len(self.points)) + assert_arrays_equal(streamlines.points, self.points) + assert_arrays_equal(streamlines.scalars, self.colors) + assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + + # Check if we can iterate through the streamlines. + for point, scalar, prop in streamlines: + pass + + # Retrieves streamlines by their index + for i, (points, scalars, props) in enumerate(streamlines): + points_i, scalars_i, props_i = streamlines[i] + assert_array_equal(points_i, points) + assert_array_equal(scalars_i, scalars) + assert_array_equal(props_i, props) + + def test_streamlines_creation_from_generators(self): + # Points, scalars and properties + points = (x for x in self.points) + scalars = (x for x in self.colors) + properties = (x for x in self.mean_curvature_torsion) + + # To create streamlines from generators use LazyStreamlines. + assert_raises(TypeError, Streamlines, points, scalars, properties) + + def test_streamlines_creation_from_functions(self): + # Points, scalars and properties + points = lambda: (x for x in self.points) + scalars = lambda: (x for x in self.colors) + properties = lambda: (x for x in self.mean_curvature_torsion) + + # To create streamlines from functions use LazyStreamlines. + assert_raises(TypeError, Streamlines, points, scalars, properties) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 251be7a61d..04dea5a88e 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -8,7 +8,7 @@ from nose.tools import assert_equal, assert_raises import nibabel as nib -from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines.base_format import Streamlines, LazyStreamlines from nibabel.streamlines.base_format import DataError, HeaderError from nibabel.streamlines.header import Field from nibabel.streamlines.trk import TrkFile @@ -36,9 +36,9 @@ def setUp(self): np.array([3.11, 3.22], dtype="f4")] def test_load_empty_file(self): - empty_trk = nib.streamlines.load(self.empty_trk_filename, lazy_load=False) + empty_trk = nib.streamlines.load(self.empty_trk_filename, ref=None, lazy_load=False) - hdr = empty_trk.get_header() + hdr = empty_trk.header assert_equal(hdr[Field.NB_STREAMLINES], 0) assert_equal(len(empty_trk), 0) @@ -56,9 +56,9 @@ def test_load_empty_file(self): pass def test_load_simple_file(self): - simple_trk = nib.streamlines.load(self.simple_trk_filename, lazy_load=False) + simple_trk = nib.streamlines.load(self.simple_trk_filename, ref=None, lazy_load=False) - hdr = simple_trk.get_header() + hdr = simple_trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(simple_trk), len(self.points)) @@ -76,9 +76,9 @@ def test_load_simple_file(self): pass # Test lazy_load - simple_trk = nib.streamlines.load(self.simple_trk_filename, lazy_load=True) + simple_trk = nib.streamlines.load(self.simple_trk_filename, ref=None, lazy_load=True) - hdr = simple_trk.get_header() + hdr = simple_trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) 
assert_equal(len(simple_trk), len(self.points)) @@ -96,9 +96,9 @@ def test_load_simple_file(self): pass def test_load_complex_file(self): - complex_trk = nib.streamlines.load(self.complex_trk_filename, lazy_load=False) + complex_trk = nib.streamlines.load(self.complex_trk_filename, ref=None, lazy_load=False) - hdr = complex_trk.get_header() + hdr = complex_trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(complex_trk), len(self.points)) @@ -115,9 +115,9 @@ def test_load_complex_file(self): for point, scalar, prop in complex_trk: pass - complex_trk = nib.streamlines.load(self.complex_trk_filename, lazy_load=True) + complex_trk = nib.streamlines.load(self.complex_trk_filename, ref=None, lazy_load=True) - hdr = complex_trk.get_header() + hdr = complex_trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(complex_trk), len(self.points)) @@ -142,9 +142,9 @@ def test_write_simple_file(self): simple_trk_file.seek(0, os.SEEK_SET) - simple_trk = nib.streamlines.load(simple_trk_file, lazy_load=False) + simple_trk = nib.streamlines.load(simple_trk_file, ref=None, lazy_load=False) - hdr = simple_trk.get_header() + hdr = simple_trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(simple_trk), len(self.points)) @@ -170,9 +170,9 @@ def test_write_complex_file(self): trk_file.seek(0, os.SEEK_SET) - trk = nib.streamlines.load(trk_file, lazy_load=False) + trk = nib.streamlines.load(trk_file, ref=None, lazy_load=False) - hdr = trk.get_header() + hdr = trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(trk), len(self.points)) @@ -197,9 +197,9 @@ def test_write_complex_file(self): trk_file.seek(0, os.SEEK_SET) - trk = nib.streamlines.load(trk_file, lazy_load=False) + trk = nib.streamlines.load(trk_file, ref=None, lazy_load=False) - hdr = trk.get_header() + hdr = trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(trk), len(self.points)) @@ -224,9 +224,9 @@ def test_write_complex_file(self): trk_file.seek(0, os.SEEK_SET) - trk = nib.streamlines.load(trk_file, lazy_load=False) + trk = nib.streamlines.load(trk_file, ref=None, lazy_load=False) - hdr = trk.get_header() + hdr = trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(trk), len(self.points)) @@ -293,19 +293,17 @@ def test_write_file_from_generator(self): gen_scalars = (scalar for scalar in self.colors) gen_properties = (prop for prop in self.mean_curvature_torsion) - assert_raises(HeaderError, Streamlines, points=gen_points, scalars=gen_scalars, properties=gen_properties) - - hdr = {Field.NB_STREAMLINES: len(self.points)} - streamlines = Streamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties, hdr=hdr) + streamlines = LazyStreamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties) + #streamlines.header[Field.NB_STREAMLINES] = len(self.points) trk_file = BytesIO() TrkFile.save(streamlines, trk_file) trk_file.seek(0, os.SEEK_SET) - trk = nib.streamlines.load(trk_file, lazy_load=False) + trk = nib.streamlines.load(trk_file, ref=None, lazy_load=False) - hdr = trk.get_header() + hdr = trk.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(trk), len(self.points)) diff --git a/nibabel/streamlines/tests/test_utils.py b/nibabel/streamlines/tests/test_utils.py index a0dbcd8b61..1e6ae83b3d 100644 --- a/nibabel/streamlines/tests/test_utils.py +++ b/nibabel/streamlines/tests/test_utils.py @@ -12,7 
+12,7 @@ import nibabel.streamlines.utils as streamline_utils -from nibabel.streamlines.base_format import Streamlines +from nibabel.streamlines.base_format import Streamlines, LazyStreamlines from nibabel.streamlines.base_format import HeaderError from nibabel.streamlines.header import Field @@ -116,9 +116,9 @@ def setUp(self): def test_load_empty_file(self): for empty_filename in self.empty_filenames: - empty_streamlines = streamline_utils.load(empty_filename, lazy_load=False) + empty_streamlines = streamline_utils.load(empty_filename, None, lazy_load=False) - hdr = empty_streamlines.get_header() + hdr = empty_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], 0) assert_equal(len(empty_streamlines), 0) @@ -139,9 +139,10 @@ def test_load_empty_file(self): def test_load_simple_file(self): for simple_filename in self.simple_filenames: - simple_streamlines = streamline_utils.load(simple_filename, lazy_load=False) + simple_streamlines = streamline_utils.load(simple_filename, None, lazy_load=False) + assert_true(type(simple_streamlines), Streamlines) - hdr = simple_streamlines.get_header() + hdr = simple_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(simple_streamlines), len(self.points)) @@ -156,11 +157,12 @@ def test_load_simple_file(self): pass # Test lazy_load - simple_streamlines = streamline_utils.load(simple_filename, lazy_load=True) + simple_streamlines = streamline_utils.load(simple_filename, None, lazy_load=True) + assert_true(type(simple_streamlines), LazyStreamlines) - hdr = simple_streamlines.get_header() + hdr = simple_streamlines.header + assert_true(Field.NB_STREAMLINES in hdr) assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) - assert_equal(len(simple_streamlines), len(self.points)) points = simple_streamlines.points assert_arrays_equal(points, self.points) @@ -171,9 +173,9 @@ def test_load_simple_file(self): def test_load_complex_file(self): for complex_filename in self.complex_filenames: - complex_streamlines = streamline_utils.load(complex_filename, lazy_load=False) + complex_streamlines = streamline_utils.load(complex_filename, None, lazy_load=False) - hdr = complex_streamlines.get_header() + hdr = complex_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(complex_streamlines), len(self.points)) @@ -184,9 +186,9 @@ def test_load_complex_file(self): for point, scalar, prop in complex_streamlines: pass - complex_streamlines = streamline_utils.load(complex_filename, lazy_load=True) + complex_streamlines = streamline_utils.load(complex_filename, None, lazy_load=True) - hdr = complex_streamlines.get_header() + hdr = complex_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(complex_streamlines), len(self.points)) @@ -203,9 +205,9 @@ def test_save_simple_file(self): streamlines = Streamlines(self.points) streamline_utils.save(streamlines, f.name) - simple_streamlines = streamline_utils.load(f, lazy_load=False) + simple_streamlines = streamline_utils.load(f, None, lazy_load=False) - hdr = simple_streamlines.get_header() + hdr = simple_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(simple_streamlines), len(self.points)) @@ -219,9 +221,9 @@ def test_save_complex_file(self): streamlines = Streamlines(self.points, scalars=self.colors) streamline_utils.save(streamlines, f.name) - complex_streamlines = streamline_utils.load(f, lazy_load=False) + complex_streamlines = streamline_utils.load(f, 
None, lazy_load=False) - hdr = complex_streamlines.get_header() + hdr = complex_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(complex_streamlines), len(self.points)) @@ -237,9 +239,9 @@ def test_save_complex_file(self): streamlines = Streamlines(self.points, properties=self.mean_curvature_torsion) streamline_utils.save(streamlines, f.name) - complex_streamlines = streamline_utils.load(f, lazy_load=False) + complex_streamlines = streamline_utils.load(f, None, lazy_load=False) - hdr = complex_streamlines.get_header() + hdr = complex_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(complex_streamlines), len(self.points)) @@ -255,9 +257,9 @@ def test_save_complex_file(self): streamlines = Streamlines(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) streamline_utils.save(streamlines, f.name) - complex_streamlines = streamline_utils.load(f, lazy_load=False) + complex_streamlines = streamline_utils.load(f, None, lazy_load=False) - hdr = complex_streamlines.get_header() + hdr = complex_streamlines.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(complex_streamlines), len(self.points)) @@ -276,15 +278,13 @@ def test_save_file_from_generator(self): gen_scalars = (scalar for scalar in self.colors) gen_properties = (prop for prop in self.mean_curvature_torsion) - assert_raises(HeaderError, Streamlines, points=gen_points, scalars=gen_scalars, properties=gen_properties) - - hdr = {Field.NB_STREAMLINES: len(self.points)} - streamlines = Streamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties, hdr=hdr) + streamlines = LazyStreamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties) + #streamlines.hdr[Field.NB_STREAMLINES] = len(self.points) streamline_utils.save(streamlines, f.name) - streamlines_loaded = streamline_utils.load(f, lazy_load=False) + streamlines_loaded = streamline_utils.load(f, None, lazy_load=False) - hdr = streamlines_loaded.get_header() + hdr = streamlines_loaded.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(streamlines_loaded), len(self.points)) @@ -303,15 +303,13 @@ def test_save_file_from_function(self): gen_scalars = lambda: (scalar for scalar in self.colors) gen_properties = lambda: (prop for prop in self.mean_curvature_torsion) - assert_raises(HeaderError, Streamlines, points=gen_points, scalars=gen_scalars, properties=gen_properties) - - hdr = {Field.NB_STREAMLINES: len(self.points)} - streamlines = Streamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties, hdr=hdr) + streamlines = LazyStreamlines(points=gen_points, scalars=gen_scalars, properties=gen_properties) + #streamlines.hdr[Field.NB_STREAMLINES] = len(self.points) streamline_utils.save(streamlines, f.name) - streamlines_loaded = streamline_utils.load(f, lazy_load=False) + streamlines_loaded = streamline_utils.load(f, None, lazy_load=False) - hdr = streamlines_loaded.get_header() + hdr = streamlines_loaded.header assert_equal(hdr[Field.NB_STREAMLINES], len(self.points)) assert_equal(len(streamlines_loaded), len(self.points)) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index c256bcc0ec..28e27896e1 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -13,7 +13,7 @@ from nibabel.openers import Opener from nibabel.volumeutils import (native_code, swapped_code) -from nibabel.streamlines.base_format import Streamlines, StreamlinesFile 
+from nibabel.streamlines.base_format import Streamlines, LazyStreamlines, StreamlinesFile from nibabel.streamlines.header import Field from nibabel.streamlines.base_format import DataError, HeaderError @@ -144,7 +144,7 @@ def sanity_check(cls, fileobj): pass # Nothing more to do here else: warnings.warn("NiBabel only supports versions 1 and 2 (not v.{0}).".format(hdr_rec['version'])) - os.seek(start_position, os.SEEK_CUR) # Set the file position where it was. + f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. return False # Convert the first record of `hdr_rec` into a dictionnary @@ -157,7 +157,7 @@ def sanity_check(cls, fileobj): hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0].newbyteorder())) # Swap byte order if hdr['hdr_size'] != cls.HEADER_SIZE: warnings.warn("Invalid hdr_size: {0} instead of {1}".format(hdr['hdr_size'], cls.HEADER_SIZE)) - os.seek(start_position, os.SEEK_CUR) # Set the file position where it was. + f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. return False # By default, the voxel order is LPS. @@ -166,7 +166,6 @@ def sanity_check(cls, fileobj): is_consistent = False warnings.warn("Voxel order is not specified, will assume 'LPS' since it is Trackvis software's default.") - # Add more header fields implied by trk format. i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4") f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4") @@ -198,12 +197,12 @@ def sanity_check(cls, fileobj): 'the actual number of streamlines contained in this file ({1}). ' ).format(hdr[Field.NB_STREAMLINES], nb_streamlines)) - os.seek(start_position, os.SEEK_CUR) # Set the file position where it was. + f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. return is_consistent @classmethod - def load(cls, fileobj, lazy_load=True): + def load(cls, fileobj, hdr={}, lazy_load=False): ''' Loads streamlines from a file-like object. Parameters @@ -213,9 +212,11 @@ def load(cls, fileobj, lazy_load=True): pointing to TRK file (and ready to read from the beginning of the TRK header) - lazy_load : boolean + hdr : dict (optional) + + lazy_load : boolean (optional) Load streamlines in a lazy manner i.e. they will not be kept - in memory. For postprocessing speed, turn off this option. + in memory. Returns ------- @@ -223,11 +224,7 @@ def load(cls, fileobj, lazy_load=True): Returns an object containing streamlines' data and header information. See 'nibabel.Streamlines'. ''' - hdr = {} - with Opener(fileobj) as f: - hdr['pos_header'] = f.tell() - # Read header hdr_str = f.read(header_2_dtype.itemsize) hdr_rec = np.fromstring(string=hdr_str, dtype=header_2_dtype) @@ -235,12 +232,12 @@ def load(cls, fileobj, lazy_load=True): if hdr_rec['version'] == 1: hdr_rec = np.fromstring(string=hdr_str, dtype=header_1_dtype) elif hdr_rec['version'] == 2: - pass # Nothing more to do here + pass # Nothing more to do else: raise HeaderError('NiBabel only supports versions 1 and 2.') # Convert the first record of `hdr_rec` into a dictionnary - hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0])) + hdr.update(dict(zip(hdr_rec.dtype.names, hdr_rec[0]))) # Check endianness hdr[Field.ENDIAN] = native_code @@ -255,21 +252,30 @@ def load(cls, fileobj, lazy_load=True): if hdr[Field.VOXEL_ORDER] == "": hdr[Field.VOXEL_ORDER] = "LPS" - # Add more header fields implied by trk format. - hdr[Field.WORLD_ORDER] = "RAS" + # Keep the file position where the data begin. hdr['pos_data'] = f.tell() + # If 'count' field is 0, i.e. not provided, we have to loop until the EOF. 
+ if hdr[Field.NB_STREAMLINES] == 0: + del hdr[Field.NB_STREAMLINES] + points = lambda: (x[0] for x in TrkFile._read_data(hdr, fileobj)) scalars = lambda: (x[1] for x in TrkFile._read_data(hdr, fileobj)) properties = lambda: (x[2] for x in TrkFile._read_data(hdr, fileobj)) data = lambda: TrkFile._read_data(hdr, fileobj) if lazy_load: - streamlines = Streamlines(points, scalars, properties, hdr=hdr) - streamlines.data = data - return streamlines + count = TrkFile._count(hdr, fileobj) + if Field.NB_STREAMLINES in hdr: + count = hdr[Field.NB_STREAMLINES] + + streamlines = LazyStreamlines(points, scalars, properties, data=data, count=count) + else: + streamlines = Streamlines(*zip(*data())) - return Streamlines(*zip(*data()), hdr=hdr) + # Set available header's information + streamlines.header.update(hdr) + return streamlines @staticmethod def _read_data(hdr, fileobj): @@ -278,6 +284,8 @@ def _read_data(hdr, fileobj): f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4") with Opener(fileobj) as f: + start_position = f.tell() + nb_pts_and_scalars = 3 + int(hdr[Field.NB_SCALARS_PER_POINT]) pts_and_scalars_size = nb_pts_and_scalars * f4_dtype.itemsize @@ -298,20 +306,79 @@ def _read_data(hdr, fileobj): dtype=f4_dtype, count=hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) + # Set the file position at the beginning of the data. f.seek(hdr['pos_data'], os.SEEK_SET) - for i in xrange(hdr[Field.NB_STREAMLINES]): + #for i in xrange(hdr[Field.NB_STREAMLINES]): + nb_streamlines = hdr.get(Field.NB_STREAMLINES, np.inf) + i = 0 + while i < nb_streamlines: + nb_pts_str = f.read(i4_dtype.itemsize) + + # Check if we reached EOF + if len(nb_pts_str) == 0: + break + # Read number of points of the next streamline. - nb_pts = struct.unpack(i4_dtype.str[:-1], f.read(i4_dtype.itemsize))[0] + nb_pts = struct.unpack(i4_dtype.str[:-1], nb_pts_str)[0] # Read streamline's data pts, scalars = read_pts_and_scalars(nb_pts) properties = read_properties() + + # TRK's streamlines are in 'voxelmm' space, we send them to voxel space. + pts = pts / hdr[Field.VOXEL_SIZES] + yield pts, scalars, properties + i += 1 + + # In case the 'count' field was not provided. + hdr[Field.NB_STREAMLINES] = i + + # Set the file position where it was. + f.seek(start_position, os.SEEK_CUR) + + @staticmethod + def _count(hdr, fileobj): + ''' Count streamlines from a file-like object using a TRK's header. ''' + nb_streamlines = 0 + + with Opener(fileobj) as f: + start_position = f.tell() + + i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4") + f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4") + + pts_and_scalars_size = (3 + hdr[Field.NB_SCALARS_PER_POINT]) * f4_dtype.itemsize + properties_size = hdr[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize + + # Set the file position at the beginning of the data. + f.seek(hdr['pos_data'], os.SEEK_SET) + + # Count the actual number of streamlines. + while True: + # Read number of points of the streamline + buf = f.read(i4_dtype.itemsize) + + if buf == '': + break # EOF + + nb_pts = struct.unpack(i4_dtype.str[:-1], buf)[0] + bytes_to_skip = nb_pts * pts_and_scalars_size + bytes_to_skip += properties_size + + # Seek to the next streamline in the file. + f.seek(bytes_to_skip * f4_dtype.itemsize, os.SEEK_CUR) + + nb_streamlines += 1 + + f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. + + return nb_streamlines @classmethod - def get_empty_header(cls): - ''' Return an empty TRK's header. ''' + def create_empty_header(cls): + ''' Return an empty TRK compliant header. 
''' hdr = np.zeros(1, dtype=header_2_dtype) #Default values @@ -338,16 +405,22 @@ def save(cls, streamlines, fileobj): If string, a filename; otherwise an open file-like object pointing to TRK file (and ready to read from the beginning of the TRK header data) + + hdr : dict (optional) + + Notes + ----- + Streamlines are assumed to be in voxel space. ''' - hdr = cls.get_empty_header() + hdr = cls.create_empty_header() #Override hdr's fields by those contain in `streamlines`'s header - for k, v in streamlines.get_header().items(): + for k, v in streamlines.header.items(): if k in header_2_dtype.fields.keys(): hdr[k] = v # Check which endianess to use to write data. - endianess = streamlines.get_header().get(Field.ENDIAN, native_code) + endianess = streamlines.header.get(Field.ENDIAN, native_code) if endianess == swapped_code: hdr = hdr.newbyteorder() @@ -364,6 +437,7 @@ def save(cls, streamlines, fileobj): # Write header + data of streamlines with Opener(fileobj, mode="wb") as f: pos = f.tell() + # Write header f.write(hdr[0].tostring()) for points, scalars, properties in streamlines: @@ -374,6 +448,9 @@ def save(cls, streamlines, fileobj): scalars = np.array(scalars, dtype=f4_dtype).reshape((len(points), -1)) properties = np.array(properties, dtype=f4_dtype) + # TRK's streamlines need to be in 'voxelmm' space + points = points * hdr[Field.VOXEL_SIZES] + data = struct.pack(i4_dtype.str[:-1], len(points)) data += np.concatenate((points, scalars), axis=1).tostring() data += properties.tostring() @@ -419,7 +496,7 @@ def pretty_print(streamlines): info : string Header's information relevant to the TRK format. ''' - hdr = streamlines.get_header() + hdr = streamlines.header info = "" info += "MAGIC NUMBER: {0}".format(hdr[Field.MAGIC_NUMBER]) @@ -432,7 +509,7 @@ def pretty_print(streamlines): info += "nb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) info += "property_name:\n {0}".format("\n".join(hdr['property_name'])) info += "vox_to_world: {0}".format(hdr[Field.VOXEL_TO_WORLD]) - info += "world_order: {0}".format(hdr[Field.WORLD_ORDER]) + #info += "world_order: {0}".format(hdr[Field.WORLD_ORDER]) info += "voxel_order: {0}".format(hdr[Field.VOXEL_ORDER]) info += "image_orientation_patient: {0}".format(hdr['image_orientation_patient']) info += "pad1: {0}".format(hdr['pad1']) diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index bb5515988d..0afc7bccda 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -2,6 +2,9 @@ from ..externals.six import string_types +from nibabel.streamlines import header +from nibabel.streamlines.base_format import LazyStreamlines + from nibabel.streamlines.trk import TrkFile #from nibabel.streamlines.tck import TckFile #from nibabel.streamlines.vtk import VtkFile @@ -61,8 +64,8 @@ def detect_format(fileobj): return None -def load(fileobj, lazy_load=True, anat=None): - ''' Loads streamlines from a file-like object in their native space. +def load(fileobj, ref, lazy_load=False): + ''' Loads streamlines from a file-like object in voxel space. Parameters ---------- @@ -71,26 +74,32 @@ def load(fileobj, lazy_load=True, anat=None): pointing to a streamlines file (and ready to read from the beginning of the streamlines file's header) + ref : filename | `Nifti1Image` object | 2D array (3,3) | 2D array (4,4) | None + Reference space where streamlines have been created. + + lazy_load : boolean (optional) + Load streamlines in a lazy manner i.e. they will not be kept + in memory. 
+
     Returns
     -------
     obj : instance of ``Streamlines``
        Returns an instance of a ``Streamlines`` class containing data and metadata
        of streamlines loaded from ``fileobj``.
     '''
-
-    # TODO: Ask everyone what should be the behavior if the anat is provided.
-    # if anat is None:
-    #     warnings.warn("WARNING: Streamlines will be loaded in their native space (i.e. as they were saved).")
-
     streamlines_file = detect_format(fileobj)

     if streamlines_file is None:
         raise TypeError("Unknown format for 'fileobj': {0}!".format(fileobj))

-    return streamlines_file.load(fileobj, lazy_load=lazy_load)
+    hdr = {}
+    if ref is not None:
+        hdr = header.create_header_from_reference(ref)

+    return streamlines_file.load(fileobj, hdr=hdr, lazy_load=lazy_load)

-def save(streamlines, filename):
+
+def save(streamlines, filename, ref=None):
     ''' Saves a ``Streamlines`` object to a file

     Parameters
@@ -100,13 +109,21 @@
     filename : string
        Name of the file where the streamlines will be saved. The format will
        be guessed from ``filename``.
-    '''

+    ref : filename | `Nifti1Image` object | 2D array (3,3) | 2D array (4,4) | None (optional)
+       Reference space the streamlines belong to. Default: get ref from `streamlines.header`.
+    '''
    streamlines_file = detect_format(filename)

     if streamlines_file is None:
         raise TypeError("Unknown format for 'filename': {0}!".format(filename))

+    if ref is not None:
+        # Create a `LazyStreamlines` from `streamlines` but using the new reference image.
+        # Grab the original header before rebinding `streamlines`, otherwise the
+        # first update below would be a no-op self-update.
+        original_header = streamlines.header
+        streamlines = LazyStreamlines(data=iter(streamlines))
+        streamlines.header.update(original_header)
+        streamlines.header.update(header.create_header_from_reference(ref))
+
     streamlines_file.save(streamlines, filename)

From c6db2395db47536d79254a418491c444cacd0577 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Wed, 4 Mar 2015 14:50:28 -0500
Subject: [PATCH 004/135] Fixed bug in TRK count function

---
 nibabel/streamlines/tests/test_trk.py | 11 +++++++++++
 nibabel/streamlines/trk.py            | 10 +++++-----
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py
index 04dea5a88e..cb32cb6334 100644
--- a/nibabel/streamlines/tests/test_trk.py
+++ b/nibabel/streamlines/tests/test_trk.py
@@ -134,6 +134,17 @@ def test_load_complex_file(self):
         for point, scalar, prop in complex_trk:
             pass

+    def test_load_file_with_no_count(self):
+        trk_file = open(self.simple_trk_filename, 'rb').read()
+        # Simulate a TRK file where count was not provided.
+        count = np.array(0, dtype="int32").tostring()
+        trk_file = trk_file[:1000-12] + count + trk_file[1000-8:]
+        streamlines = nib.streamlines.load(BytesIO(trk_file), ref=None, lazy_load=False)
+        assert_equal(len(streamlines), len(self.points))
+
+        streamlines = nib.streamlines.load(BytesIO(trk_file), ref=None, lazy_load=True)
+        assert_equal(len(streamlines), len(self.points))
+
     def test_write_simple_file(self):
         streamlines = Streamlines(self.points)

diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index 28e27896e1..dbeda618ae 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -178,7 +178,7 @@ def sanity_check(cls, fileobj):
                 # Read number of points of the streamline
                 buf = f.read(i4_dtype.itemsize)

-                if buf == '':
+                if len(buf) == 0:
                     break  # EOF

                 nb_pts = struct.unpack(i4_dtype.str[:-1], buf)[0]
@@ -187,7 +187,7 @@
                 bytes_to_skip += properties_size

                 # Seek to the next streamline in the file.
- f.seek(bytes_to_skip * f4_dtype.itemsize, os.SEEK_CUR) + f.seek(bytes_to_skip, os.SEEK_CUR) nb_streamlines += 1 @@ -265,7 +265,7 @@ def load(cls, fileobj, hdr={}, lazy_load=False): data = lambda: TrkFile._read_data(hdr, fileobj) if lazy_load: - count = TrkFile._count(hdr, fileobj) + count = lambda: TrkFile._count(hdr, fileobj) if Field.NB_STREAMLINES in hdr: count = hdr[Field.NB_STREAMLINES] @@ -360,7 +360,7 @@ def _count(hdr, fileobj): # Read number of points of the streamline buf = f.read(i4_dtype.itemsize) - if buf == '': + if len(buf) == 0: break # EOF nb_pts = struct.unpack(i4_dtype.str[:-1], buf)[0] @@ -368,7 +368,7 @@ def _count(hdr, fileobj): bytes_to_skip += properties_size # Seek to the next streamline in the file. - f.seek(bytes_to_skip * f4_dtype.itemsize, os.SEEK_CUR) + f.seek(bytes_to_skip, os.SEEK_CUR) nb_streamlines += 1 From b0ab29c47971b3acf97aff9d94793a83a15a2d89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 4 Mar 2015 23:29:03 -0500 Subject: [PATCH 005/135] Removed unused function check_integrity --- nibabel/streamlines/trk.py | 90 +------------------------------------- 1 file changed, 1 insertion(+), 89 deletions(-) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index dbeda618ae..134beaa9d6 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -112,95 +112,6 @@ def is_correct_format(cls, fileobj): return False - @classmethod - def sanity_check(cls, fileobj): - ''' Check if data is consistent with information contained in the header. - [Might be useful] - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to TRK file (and ready to read from the beginning - of the TRK header data) - - Returns - ------- - is_consistent : boolean - Returns True if data is consistent with header, False otherwise. - ''' - is_consistent = True - - with Opener(fileobj) as f: - start_position = f.tell() - - # Read header - hdr_str = f.read(header_2_dtype.itemsize) - hdr_rec = np.fromstring(string=hdr_str, dtype=header_2_dtype) - - if hdr_rec['version'] == 1: - hdr_rec = np.fromstring(string=hdr_str, dtype=header_1_dtype) - elif hdr_rec['version'] == 2: - pass # Nothing more to do here - else: - warnings.warn("NiBabel only supports versions 1 and 2 (not v.{0}).".format(hdr_rec['version'])) - f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. - return False - - # Convert the first record of `hdr_rec` into a dictionnary - hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0])) - - # Check endianness - hdr[Field.ENDIAN] = native_code - if hdr['hdr_size'] != cls.HEADER_SIZE: - hdr[Field.ENDIAN] = swapped_code - hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0].newbyteorder())) # Swap byte order - if hdr['hdr_size'] != cls.HEADER_SIZE: - warnings.warn("Invalid hdr_size: {0} instead of {1}".format(hdr['hdr_size'], cls.HEADER_SIZE)) - f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. - return False - - # By default, the voxel order is LPS. 
- # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates - if hdr[Field.VOXEL_ORDER] == "": - is_consistent = False - warnings.warn("Voxel order is not specified, will assume 'LPS' since it is Trackvis software's default.") - - i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4") - f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4") - - pts_and_scalars_size = (3 + hdr[Field.NB_SCALARS_PER_POINT]) * f4_dtype.itemsize - properties_size = hdr[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize - - #Verify the number of streamlines specified in the header is correct. - nb_streamlines = 0 - while True: - # Read number of points of the streamline - buf = f.read(i4_dtype.itemsize) - - if len(buf) == 0: - break # EOF - - nb_pts = struct.unpack(i4_dtype.str[:-1], buf)[0] - - bytes_to_skip = nb_pts * pts_and_scalars_size - bytes_to_skip += properties_size - - # Seek to the next streamline in the file. - f.seek(bytes_to_skip, os.SEEK_CUR) - - nb_streamlines += 1 - - if hdr[Field.NB_STREAMLINES] != nb_streamlines: - is_consistent = False - warnings.warn(('The number of streamlines specified in header ({1}) does not match ' - 'the actual number of streamlines contained in this file ({1}). ' - ).format(hdr[Field.NB_STREAMLINES], nb_streamlines)) - - f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. - - return is_consistent - @classmethod def load(cls, fileobj, hdr={}, lazy_load=False): ''' Loads streamlines from a file-like object. @@ -250,6 +161,7 @@ def load(cls, fileobj, hdr={}, lazy_load=False): # By default, the voxel order is LPS. # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates if hdr[Field.VOXEL_ORDER] == "": + warnings.warn("Voxel order is not specified, will assume 'LPS' since it is Trackvis software's default.") hdr[Field.VOXEL_ORDER] = "LPS" # Keep the file position where the data begin. 
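The seek arithmetic corrected by PATCH 004 above is easy to get wrong: each TRK record is a 4-byte int32 point count followed by nb_pts * (3 + n_scalars) float32 values and n_properties float32 values, so `bytes_to_skip` is already a byte count and must not be scaled by `f4_dtype.itemsize` a second time. Below is a minimal standalone sketch of the same counting loop, assuming little-endian data and a file object positioned just past the 1000-byte TRK header; `count_trk_streamlines`, `n_scalars` and `n_properties` are illustrative names, not part of the patches.

    import os
    import struct

    def count_trk_streamlines(f, n_scalars, n_properties):
        # `f` is a binary file object positioned at the first streamline record.
        pt_size = 4 * (3 + n_scalars)     # x, y, z plus per-point scalars (float32)
        props_size = 4 * n_properties     # per-streamline properties (float32)
        count = 0
        while True:
            buf = f.read(4)               # int32: number of points in this streamline
            if len(buf) < 4:
                break                     # EOF (or a truncated record)
            nb_pts = struct.unpack('<i', buf)[0]
            # The skip is already expressed in bytes; no further itemsize scaling.
            f.seek(nb_pts * pt_size + props_size, os.SEEK_CUR)
            count += 1
        return count

The same loop doubles as a cheap consistency check: a truncated file shows up as a final read of fewer than four bytes.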
From 2b7e06560fd4c07e12379e32c90087e110dc7123 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Thu, 16 Apr 2015 23:10:09 -0400
Subject: [PATCH 006/135] Refactored code and tests

---
 nibabel/streamlines/__init__.py               | 140 ++++-
 nibabel/streamlines/base_format.py            | 276 +++++----
 nibabel/streamlines/header.py                 | 163 +++---
 nibabel/streamlines/tests/data/complex.trk    | Bin 1228 -> 1228 bytes
 nibabel/streamlines/tests/data/empty.trk      | Bin 1000 -> 1000 bytes
 nibabel/streamlines/tests/data/simple.trk     | Bin 1108 -> 1108 bytes
 nibabel/streamlines/tests/test_base_format.py | 377 +++++++------
 nibabel/streamlines/tests/test_header.py      |  37 ++
 nibabel/streamlines/tests/test_streamlines.py | 229 ++++++++
 nibabel/streamlines/tests/test_trk.py         | 325 ++++------
 nibabel/streamlines/tests/test_utils.py       | 332 +----------
 nibabel/streamlines/trk.py                    | 527 ++++++++++--------
 nibabel/streamlines/utils.py                  | 161 +-----
 13 files changed, 1342 insertions(+), 1225 deletions(-)
 create mode 100644 nibabel/streamlines/tests/test_header.py
 create mode 100644 nibabel/streamlines/tests/test_streamlines.py

diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py
index 7f00864c5f..7f3d7dcdf2 100644
--- a/nibabel/streamlines/__init__.py
+++ b/nibabel/streamlines/__init__.py
@@ -1,9 +1,137 @@
+from .header import Field
+from .base_format import Streamlines, LazyStreamlines

-from nibabel.streamlines.utils import load, save
+from nibabel.streamlines.trk import TrkFile
+#from nibabel.streamlines.tck import TckFile
+#from nibabel.streamlines.vtk import VtkFile

-from nibabel.streamlines.base_format import Streamlines, LazyStreamlines
-from nibabel.streamlines.header import Field
+# List of all supported formats
+FORMATS = {".trk": TrkFile,
+           #".tck": TckFile,
+           #".vtk": VtkFile,
+           }

-from nibabel.streamlines.trk import TrkFile
-#from nibabel.streamlines.trk import TckFile
-#from nibabel.streamlines.trk import VtkFile
+
+def is_supported(fileobj):
+    ''' Checks if the file-like object is supported by NiBabel.
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object pointing
+        to a streamlines file (and ready to read from the beginning of the
+        header)
+
+    Returns
+    -------
+    is_supported : boolean
+    '''
+    return detect_format(fileobj) is not None
+
+
+def detect_format(fileobj):
+    ''' Returns the StreamlinesFile object guessed from the file-like object.
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object pointing
+        to a streamlines file (and ready to read from the beginning of the
+        header)
+
+    Returns
+    -------
+    streamlines_file : StreamlinesFile object
+        Object that can be used to manage a streamlines file.
+        See 'nibabel.streamlines.StreamlinesFile'.
+    '''
+    for format in FORMATS.values():
+        try:
+            if format.is_correct_format(fileobj):
+                return format
+
+        except IOError:
+            pass
+
+    import os
+    from ..externals.six import string_types
+    if isinstance(fileobj, string_types):
+        _, ext = os.path.splitext(fileobj)
+        return FORMATS.get(ext, None)
+
+    return None
+
+
+def load(fileobj, ref, lazy_load=False):
+    ''' Loads streamlines from a file-like object in voxel space.
+
+    Parameters
+    ----------
+    fileobj : string or file-like object
+        If string, a filename; otherwise an open file-like object
+        pointing to a streamlines file (and ready to read from the beginning
+        of the streamlines file's header).
+ + ref : filename | `Nifti1Image` object | 2D array (4,4) + Reference space where streamlines will live in `fileobj`. + + lazy_load : boolean (optional) + Load streamlines in a lazy manner i.e. they will not be kept + in memory. + + Returns + ------- + obj : instance of `Streamlines` + Returns an instance of a `Streamlines` class containing data and metadata + of streamlines loaded from `fileobj`. + ''' + streamlines_file = detect_format(fileobj) + + if streamlines_file is None: + raise TypeError("Unknown format for 'fileobj': {0}!".format(fileobj)) + + return streamlines_file.load(fileobj, ref, lazy_load=lazy_load) + + +def save(streamlines, filename, ref=None): + ''' Saves a `Streamlines` object to a file + + Parameters + ---------- + streamlines : `Streamlines` object + Streamlines to be saved. + + filename : str + Name of the file where the streamlines will be saved. The format will + be guessed from `filename`. + + ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) + Reference space where streamlines will live in `fileobj`. + ''' + streamlines_file = detect_format(filename) + + if streamlines_file is None: + raise TypeError("Unknown streamlines file format: '{0}'!".format(filename)) + + streamlines_file.save(streamlines, filename, ref) + + +def convert(in_fileobj, out_filename, ref): + ''' Converts a streamlines file to another format. + + Parameters + ---------- + in_fileobj : string or file-like object + If string, a filename; otherwise an open file-like object pointing + to a streamlines file (and ready to read from the beginning of the + header). + + out_filename : str + Name of the file where the streamlines will be saved. The format will + be guessed from `out_filename`. + + ref : filename | `Nifti1Image` object | 2D array (4,4) + Reference space where streamlines live in `fileobj`. + ''' + streamlines = load(in_fileobj, ref, lazy_load=True) + save(streamlines, out_filename) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 30baf4001e..a94f5728ad 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -1,9 +1,18 @@ -import numpy as np from warnings import warn -from nibabel.streamlines.header import Field +from nibabel.externals.six.moves import zip_longest +from nibabel.affines import apply_affine -from ..externals.six.moves import zip_longest +from .header import StreamlinesHeader +from .utils import pop + + +class UsageWarning(Warning): + pass + + +class HeaderWarning(Warning): + pass class HeaderError(Exception): @@ -37,20 +46,12 @@ class Streamlines(object): Sequence of T ndarrays of shape (P,) where T is the number of streamlines defined by ``points``, P is the number of properties associated to each streamline. - - hdr : dict - Header containing meta information about the streamlines. For a list - of common header's fields to use as keys see `nibabel.streamlines.Field`. ''' - def __init__(self, points=[], scalars=[], properties=[]): #, hdr={}): - # Create basic header from given informations. 
-        self._header = {}
-        self._header[Field.VOXEL_TO_WORLD] = np.eye(4)
-
-        self.points = points
-        self.scalars = scalars
-        self.properties = properties
-
+    def __init__(self, points=None, scalars=None, properties=None):
+        self._header = StreamlinesHeader()
+        self.points = points
+        self.scalars = scalars
+        self.properties = properties

     @property
     def header(self):
@@ -62,8 +63,8 @@ def points(self):

     @points.setter
     def points(self, value):
-        self._points = value
-        self._header[Field.NB_STREAMLINES] = len(self.points)
+        self._points = value if value else []
+        self.header.nb_streamlines = len(self.points)

     @property
     def scalars(self):
@@ -71,10 +72,11 @@ def scalars(self):

     @scalars.setter
     def scalars(self, value):
-        self._scalars = value
-        self._header[Field.NB_SCALARS_PER_POINT] = 0
-        if len(self.scalars) > 0:
-            self._header[Field.NB_SCALARS_PER_POINT] = len(self.scalars[0])
+        self._scalars = value if value else []
+        self.header.nb_scalars_per_point = 0
+
+        if len(self.scalars) > 0 and len(self.scalars[0]) > 0:
+            self.header.nb_scalars_per_point = len(self.scalars[0][0])

     @property
     def properties(self):
@@ -82,10 +84,11 @@ def properties(self):

     @properties.setter
     def properties(self, value):
-        self._properties = value
-        self._header[Field.NB_PROPERTIES_PER_STREAMLINE] = 0
+        self._properties = value if value else []
+        self.header.nb_properties_per_streamline = 0
+
         if len(self.properties) > 0:
-            self._header[Field.NB_PROPERTIES_PER_STREAMLINE] = len(self.properties[0])
+            self.header.nb_properties_per_streamline = len(self.properties[0])

     def __iter__(self):
         return zip_longest(self.points, self.scalars, self.properties, fillvalue=[])
@@ -100,49 +103,89 @@ def __getitem__(self, idx):
         if len(self.properties) > 0:
             properties = self.properties[idx]

+        if type(idx) is slice:
+            return list(zip_longest(pts, scalars, properties, fillvalue=[]))
+
         return pts, scalars, properties

     def __len__(self):
         return len(self.points)

+    def to_world_space(self, as_generator=False):
+        affine = self.header.voxel_to_world
+        new_points = (apply_affine(affine, pts) for pts in self.points)
+
+        if not as_generator:
+            return list(new_points)
+
+        return new_points
+

 class LazyStreamlines(Streamlines):
     ''' Class containing information about streamlines.

-    Streamlines objects have three main properties: ``points``, ``scalars``
-    and ``properties``. Streamlines objects can be iterate over producing
-    tuple of ``points``, ``scalars`` and ``properties`` for each streamline.
+    Streamlines objects have four main properties: ``header``, ``points``,
+    ``scalars`` and ``properties``. Streamlines objects are iterable and
+    produce tuples of ``points``, ``scalars`` and ``properties`` for each
+    streamline.

     Parameters
     ----------
-    points : sequence of ndarray of shape (N, 3)
-       Sequence of T streamlines. One streamline is an ndarray of shape (N, 3)
-       where N is the number of points in a streamline.
-
-    scalars : sequence of ndarray of shape (N, M)
-       Sequence of T ndarrays of shape (N, M) where T is the number of
-       streamlines defined by ``points``, N is the number of points
-       for a particular streamline and M is the number of scalars
-       associated to each point (excluding the three coordinates).
-
-    properties : sequence of ndarray of shape (P,)
-       Sequence of T ndarrays of shape (P,) where T is the number of
-       streamlines defined by ``points``, P is the number of properties
+    points_func : coroutine outputting (N,3) array-like (optional)
+       Function yielding streamlines' points. One streamline's points is
+       an array-like of shape (N,3) where N is the number of points in a
+       streamline.
+
+    scalars_func : coroutine outputting (N,M) array-like (optional)
+       Function yielding streamlines' scalars. One streamline's scalars is
+       an array-like of shape (N,M) where N is the number of points for a
+       particular streamline and M is the number of scalars associated to
+       each point (excluding the three coordinates).
+
+    properties_func : coroutine outputting (P,) array-like (optional)
+       Function yielding streamlines' properties. One streamline's properties
+       is an array-like of shape (P,) where P is the number of properties
       associated to each streamline.

-    hdr : dict
-       Header containing meta information about the streamlines. For a list
-       of common header's fields to use as keys see `nibabel.streamlines.Field`.
-    '''
-    def __init__(self, points=[], scalars=[], properties=[], data=None, count=None, getitem=None): #, hdr={}):
-        super(LazyStreamlines, self).__init__(points, scalars, properties)
+    getitem_func : function `idx -> 3-tuples` (optional)
+       Function returning streamlines (one or a list of 3-tuples) given
+       an index or a slice (i.e. the __getitem__ function to use).

+    Notes
+    -----
+    If provided, ``scalars`` and ``properties`` must yield the same number of
+    values as ``points``.
+    '''
+    def __init__(self, points_func=lambda: [], scalars_func=lambda: [], properties_func=lambda: [], getitem_func=None):
+        super(LazyStreamlines, self).__init__(points_func, scalars_func, properties_func)
         self._data = lambda: zip_longest(self.points, self.scalars, self.properties, fillvalue=[])
-        if data is not None:
-            self._data = data if callable(data) else lambda: data
+        self._getitem = getitem_func

-        self._count = count
-        self._getitem = getitem
+    @classmethod
+    def create_from_data(cls, data_func):
+        ''' Creates a ``LazyStreamlines`` object from a coroutine yielding 3-tuples.
+
+        Parameters
+        ----------
+        data_func : coroutine outputting tuple
+           Function yielding 3-tuples, (streamline's points, streamline's
+           scalars, streamline's properties). A streamline's points is an
+           array-like of shape (N,3), a streamline's scalars is an array-like
+           of shape (N,M) and streamline's properties is an array-like of
+           shape (P,) where N is the number of points for a particular
+           streamline, M is the number of scalars associated to each point
+           (excluding the three coordinates) and P is the number of properties
+           associated to each streamline.
+ ''' + if not callable(data_func): + raise TypeError("`data` must be a coroutine.") + + lazy_streamlines = cls() + lazy_streamlines._data = data_func + lazy_streamlines.points = lambda: (x[0] for x in data_func()) + lazy_streamlines.scalars = lambda: (x[1] for x in data_func()) + lazy_streamlines.properties = lambda: (x[2] for x in data_func()) + return lazy_streamlines @property def points(self): @@ -150,7 +193,10 @@ def points(self): @points.setter def points(self, value): - self._points = value if callable(value) else lambda: value + if not callable(value): + raise TypeError("`points` must be a coroutine.") + + self._points = value @property def scalars(self): @@ -158,7 +204,14 @@ def scalars(self): @scalars.setter def scalars(self, value): - self._scalars = value if callable(value) else lambda: value + if not callable(value): + raise TypeError("`scalars` must be a coroutine.") + + self._scalars = value + self.header.nb_scalars_per_point = 0 + scalars = pop(self.scalars) + if scalars is not None and len(scalars) > 0: + self.header.nb_scalars_per_point = len(scalars[0]) @property def properties(self): @@ -166,7 +219,14 @@ def properties(self): @properties.setter def properties(self, value): - self._properties = value if callable(value) else lambda: value + if not callable(value): + raise TypeError("`properties` must be a coroutine.") + + self._properties = value + self.header.nb_properties_per_streamline = 0 + properties = pop(self.properties) + if properties is not None: + self.header.nb_properties_per_streamline = len(properties) def __getitem__(self, idx): if self._getitem is None: @@ -175,37 +235,28 @@ def __getitem__(self, idx): return self._getitem(idx) def __iter__(self): - return self._data() + i = 0 + for i, s in enumerate(self._data(), start=1): + yield s + + # To be safe, update information about number of streamlines. + self.header.nb_streamlines = i def __len__(self): - # If length is unknown, we'll try to get it as rapidely and accurately as possible. - if self._count is None: - # Length might be contained in the header. - if Field.NB_STREAMLINES in self.header: - return self.header[Field.NB_STREAMLINES] - - if callable(self._count): - # Length might be obtained by re-parsing the file (if streamlines come from one). - self._count = self._count() - - if self._count is None: - try: - # Will work if `points` is a finite sequence (e.g. list, ndarray) - self._count = len(self.points) - except: - pass - - if self._count is None: - # As a last resort, count them by iterating through the list of points (while keeping a copy). + # If length is unknown, we obtain it by iterating through streamlines. + if self.header.nb_streamlines is None: warn("Number of streamlines will be determined manually by looping" - " through the streamlines. Note this will consume any" - " generator used to create this `Streamlines`object. If you" - " know the actual number of streamlines, you might want to" - " set `Field.NB_STREAMLINES` of `self.header` beforehand.") - + " through the streamlines. If you know the actual number of" + " streamlines, you might want to set it beforehand via" + " `self.header.nb_streamlines`." 
+ " Note this will consume any generators used to create this" + " `LazyStreamlines` object.", UsageWarning) return sum(1 for _ in self) - return self._count + return self.header.nb_streamlines + + def to_world_space(self): + return super(LazyStreamlines, self).to_world_space(as_generator=True) class StreamlinesFile: @@ -213,19 +264,29 @@ class StreamlinesFile: @classmethod def get_magic_number(cls): - ''' Return streamlines file's magic number. ''' + ''' Returns streamlines file's magic number. ''' + raise NotImplementedError() + + @classmethod + def can_save_scalars(cls): + ''' Tells if the streamlines format supports saving scalars. ''' + raise NotImplementedError() + + @classmethod + def can_save_properties(cls): + ''' Tells if the streamlines format supports saving properties. ''' raise NotImplementedError() @classmethod def is_correct_format(cls, fileobj): - ''' Check if the file has the right streamlines file format. + ''' Checks if the file has the right streamlines file format. Parameters ---------- fileobj : string or file-like object If string, a filename; otherwise an open file-like object pointing to a streamlines file (and ready to read from the - beginning of the header) + beginning of the header). Returns ------- @@ -234,13 +295,8 @@ def is_correct_format(cls, fileobj): ''' raise NotImplementedError() - @classmethod - def get_empty_header(cls): - ''' Return an empty streamlines file's header. ''' - raise NotImplementedError() - - @classmethod - def load(cls, fileobj, lazy_load=True): + @staticmethod + def load(fileobj, ref, lazy_load=True): ''' Loads streamlines from a file-like object. Parameters @@ -248,7 +304,10 @@ def load(cls, fileobj, lazy_load=True): fileobj : string or file-like object If string, a filename; otherwise an open file-like object pointing to a streamlines file (and ready to read from the - beginning of the header) + beginning of the header). + + ref : filename | `Nifti1Image` object | 2D array (4,4) + Reference space where streamlines live in `fileobj`. lazy_load : boolean Load streamlines in a lazy manner i.e. they will not be kept @@ -262,8 +321,8 @@ def load(cls, fileobj, lazy_load=True): ''' raise NotImplementedError() - @classmethod - def save(cls, streamlines, fileobj): + @staticmethod + def save(streamlines, fileobj, ref=None): ''' Saves streamlines to a file-like object. Parameters @@ -275,35 +334,38 @@ def save(cls, streamlines, fileobj): fileobj : string or file-like object If string, a filename; otherwise an open file-like object opened and ready to write. + + ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) + Reference space where streamlines will live in `fileobj`. ''' raise NotImplementedError() @staticmethod def pretty_print(streamlines): - ''' Gets a formatted string contaning header's information - relevant to the streamlines file format. + ''' Gets a formatted string of the header of a streamlines file format. Parameters ---------- - streamlines : Streamlines object - Object containing streamlines' data and header information. - See 'nibabel.Streamlines'. + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header). Returns ------- info : string - Header's information relevant to the streamlines file format. + Header information relevant to the streamlines file format. 
''' raise NotImplementedError() -class DynamicStreamlineFile(StreamlinesFile): - ''' Convenience class to encapsulate streamlines file format - that supports appending streamlines to an existing file. - ''' +# class DynamicStreamlineFile(StreamlinesFile): +# ''' Convenience class to encapsulate streamlines file format +# that supports appending streamlines to an existing file. +# ''' - def append(self, streamlines): - raise NotImplementedError() +# def append(self, streamlines): +# raise NotImplementedError() - def __iadd__(self, streamlines): - return self.append(streamlines) +# def __iadd__(self, streamlines): +# return self.append(streamlines) diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py index 357930c012..46e4af3a94 100644 --- a/nibabel/streamlines/header.py +++ b/nibabel/streamlines/header.py @@ -1,7 +1,6 @@ import numpy as np -import nibabel from nibabel.orientations import aff2axcodes -from nibabel.spatialimages import SpatialImage +from collections import OrderedDict class Field: @@ -21,80 +20,92 @@ class Field: ORIGIN = "origin" VOXEL_TO_WORLD = "voxel_to_world" VOXEL_ORDER = "voxel_order" - #WORLD_ORDER = "world_order" ENDIAN = "endian" -def create_header_from_reference(ref): - if type(ref) is np.ndarray: - return create_header_from_affine(ref) - elif isinstance(ref) is SpatialImage: - return create_header_from_nifti(ref) - - # Assume `ref` is a filename: - img = nibabel.load(ref) - return create_header_from_nifti(img) - - -def create_header_from_nifti(img): - ''' Creates a common streamlines' header using a spatial image. - - Based on the information of the nifti image a dictionnary is created - containing the following keys: `Field.ORIGIN`, `Field.DIMENSIONS`, - `Field.VOXEL_SIZES`, `Field.VOXEL_TO_WORLD`, `Field.WORLD_ORDER` - and `Field.VOXEL_ORDER`. - - Parameters - ---------- - img : `SpatialImage` object - Image containing information about the anatomy where streamlines - were created. - - Returns - ------- - hdr : dict - Header containing meta information about streamlines extracted - from the anatomy. - ''' - img_header = img.get_header() - affine = img_header.get_best_affine() - - hdr = {} - - hdr[Field.ORIGIN] = affine[:3, -1] - hdr[Field.DIMENSIONS] = img_header.get_data_shape()[:3] - hdr[Field.VOXEL_SIZES] = img_header.get_zooms()[:3] - hdr[Field.VOXEL_TO_WORLD] = affine - #hdr[Field.WORLD_ORDER] = "RAS" # Nifti space - hdr[Field.VOXEL_ORDER] = "".join(aff2axcodes(affine)) - - return hdr - - -def create_header_from_affine(affine): - ''' Creates a common streamlines' header using an affine matrix. - - Based on the information of the affine matrix a dictionnary is created - containing the following keys: `Field.ORIGIN`, `Field.DIMENSIONS`, - `Field.VOXEL_SIZES`, `Field.VOXEL_TO_WORLD`, `Field.WORLD_ORDER` - and `Field.VOXEL_ORDER`. - - Parameters - ---------- - affine : 2D array (3,3) | 2D array (4,4) - Affine matrix that transforms streamlines from voxel space to world-space. - - Returns - ------- - hdr : dict - Header containing meta information about streamlines. 
- ''' - hdr = {} - - hdr[Field.ORIGIN] = affine[:3, -1] - hdr[Field.VOXEL_SIZES] = np.sqrt(np.sum(affine[:3, :3]**2, axis=0)) - hdr[Field.VOXEL_TO_WORLD] = affine - #hdr[Field.WORLD_ORDER] = "RAS" # Nifti space - hdr[Field.VOXEL_ORDER] = "".join(aff2axcodes(affine)) - - return hdr +class StreamlinesHeader(object): + def __init__(self): + self._nb_streamlines = None + self._nb_scalars_per_point = None + self._nb_properties_per_streamline = None + self._voxel_to_world = np.eye(4) + self.extra = OrderedDict() + + @property + def voxel_to_world(self): + return self._voxel_to_world + + @voxel_to_world.setter + def voxel_to_world(self, value): + self._voxel_to_world = np.array(value, dtype=np.float32) + + @property + def voxel_sizes(self): + """ Get voxel sizes from voxel_to_world. """ + return np.sqrt(np.sum(self.voxel_to_world[:3, :3]**2, axis=0)) + + @voxel_sizes.setter + def voxel_sizes(self, value): + scaling = np.r_[np.array(value), [1]] + old_scaling = np.r_[np.array(self.voxel_sizes), [1]] + # Remove old scaling and apply new one + self.voxel_to_world = np.dot(np.diag(scaling/old_scaling), self.voxel_to_world) + + @property + def voxel_order(self): + """ Get voxel order from voxel_to_world. """ + return "".join(aff2axcodes(self.voxel_to_world)) + + @property + def nb_streamlines(self): + return self._nb_streamlines + + @nb_streamlines.setter + def nb_streamlines(self, value): + self._nb_streamlines = int(value) + + @property + def nb_scalars_per_point(self): + return self._nb_scalars_per_point + + @nb_scalars_per_point.setter + def nb_scalars_per_point(self, value): + self._nb_scalars_per_point = int(value) + + @property + def nb_properties_per_streamline(self): + return self._nb_properties_per_streamline + + @nb_properties_per_streamline.setter + def nb_properties_per_streamline(self, value): + self._nb_properties_per_streamline = int(value) + + @property + def extra(self): + return self._extra + + @extra.setter + def extra(self, value): + self._extra = OrderedDict(value) + + def __eq__(self, other): + return (np.allclose(self.voxel_to_world, other.voxel_to_world) and + self.nb_streamlines == other.nb_streamlines and + self.nb_scalars_per_point == other.nb_scalars_per_point and + self.nb_properties_per_streamline == other.nb_properties_per_streamline and + repr(self.extra) == repr(other.extra)) # Not the robust way, but will do! 
+ + def __repr__(self): + txt = "Header{\n" + txt += "nb_streamlines: " + repr(self.nb_streamlines) + '\n' + txt += "nb_scalars_per_point: " + repr(self.nb_scalars_per_point) + '\n' + txt += "nb_properties_per_streamline: " + repr(self.nb_properties_per_streamline) + '\n' + txt += "voxel_to_world: " + repr(self.voxel_to_world) + '\n' + txt += "voxel_sizes: " + repr(self.voxel_sizes) + '\n' + + txt += "Extra fields: {\n" + for key in sorted(self.extra.keys()): + txt += " " + repr(key) + ": " + repr(self.extra[key]) + "\n" + + txt += " }\n" + return txt + "}" diff --git a/nibabel/streamlines/tests/data/complex.trk b/nibabel/streamlines/tests/data/complex.trk index 2a96a7bd35f6ce69f05907a351466ac838de1045..9bfbe5ea60917d6f98920b3a08e460be5fb6732d 100644 GIT binary patch literal 1228 zcmWFua&-1)U<5-3h6Z~CW`F}`IBTdgn1;FsklYW7D|A4K7`j=b{J|OmL5{&hISrI= zK`H(PkiiJi0ZcOofLOr+lFuQ6)gpcN-!$gfLvsg`8jw6JO(g*JOaS5yKnwy9aa%5S zhiVZ%2Udt6nqHVXbiFX~3l0nnAAp#{5hMZuAaf;vd<`JBKoj=>@*{va14Z1H%N^(j GKL-HhO(bCe literal 1228 zcmWFua&-1)U<5-3h6Z~CW`F}`IBTdgn1;FsklYW7D|A4K7`j=b{GlEKpnMBT@h^Z3 zM!09dX~Y4N&mn@(Dj03LHZ8>)ja@W21g)<0+6@>kgov52590AKz;xaC!mPia=8QD;O77UQAr@i diff --git a/nibabel/streamlines/tests/data/empty.trk b/nibabel/streamlines/tests/data/empty.trk index 023b3c5905f222fbd76ce9f82ca72be319b427f6..e78e28403b087f627c298afc88c200dc2d50bcff 100644 GIT binary patch delta 18 acmaFC{(^nO7G~xk$KZ*Inv)+ea{vHEWCn-; delta 14 WcmaFC{(^nO7UqctI+GtTa{vG^e+97s diff --git a/nibabel/streamlines/tests/data/simple.trk b/nibabel/streamlines/tests/data/simple.trk index dc5eff4adc9cd919e7bfe31a37485ecc51348109..df601e29a7cb54b7be756c8a58aa205bc19af70a 100644 GIT binary patch literal 1108 zcmWFua&-1)U<5-3h6Z~CW*7y7Is`y*g$^hYLpN)bKUhN`$T65Gr!fOnF#+)lAcGO2 z1DIwG0I`AtNE`^@dR, P->A, I->S). IF (0 based) value [3, 3] from +# coordinates (axes L->R, P->A, I->S). If (0 based) value [3, 3] from # this matrix is 0, this means the matrix is not recorded. header_2_dtd = [(Field.MAGIC_NUMBER, 'S6'), (Field.DIMENSIONS, 'h', 3), @@ -78,132 +80,78 @@ header_2_dtype = np.dtype(header_2_dtd) -class TrkFile(StreamlinesFile): - ''' Convenience class to encapsulate TRK format. ''' - - MAGIC_NUMBER = b"TRACK" - HEADER_SIZE = 1000 - - @classmethod - def get_magic_number(cls): - ''' Return TRK's magic number. ''' - return cls.MAGIC_NUMBER +class TrkReader(object): + ''' Convenience class to encapsulate TRK file format. - @classmethod - def is_correct_format(cls, fileobj): - ''' Check if the file is in TRK format. - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to TRK file (and ready to read from the beginning - of the TRK header data) - - Returns - ------- - is_correct_format : boolean - Returns True if `fileobj` is in TRK format. - ''' - with Opener(fileobj) as f: - magic_number = f.read(5) - f.seek(-5, os.SEEK_CUR) - return magic_number == cls.MAGIC_NUMBER - - return False - - @classmethod - def load(cls, fileobj, hdr={}, lazy_load=False): - ''' Loads streamlines from a file-like object. 
- - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to TRK file (and ready to read from the beginning - of the TRK header) + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to TRK file (and ready to read from the beginning + of the TRK header) - hdr : dict (optional) + Note + ---- + TrackVis (so its file format: TRK) considers the streamline coordinate + (0,0,0) to be in the corner of the voxel whereas NiBabel's streamlines + internal representation (Voxel space) assume (0,0,0) to be in the + center of the voxel. - lazy_load : boolean (optional) - Load streamlines in a lazy manner i.e. they will not be kept - in memory. + Thus, streamlines are shifted of half a voxel on load and are shifted + back on save. + ''' + def __init__(self, fileobj): + self.fileobj = fileobj - Returns - ------- - streamlines : Streamlines object - Returns an object containing streamlines' data and header - information. See 'nibabel.Streamlines'. - ''' - with Opener(fileobj) as f: + with Opener(self.fileobj) as f: # Read header - hdr_str = f.read(header_2_dtype.itemsize) - hdr_rec = np.fromstring(string=hdr_str, dtype=header_2_dtype) + header_str = f.read(header_2_dtype.itemsize) + header_rec = np.fromstring(string=header_str, dtype=header_2_dtype) - if hdr_rec['version'] == 1: - hdr_rec = np.fromstring(string=hdr_str, dtype=header_1_dtype) - elif hdr_rec['version'] == 2: + if header_rec['version'] == 1: + header_rec = np.fromstring(string=header_str, dtype=header_1_dtype) + elif header_rec['version'] == 2: pass # Nothing more to do else: raise HeaderError('NiBabel only supports versions 1 and 2.') - # Convert the first record of `hdr_rec` into a dictionnary - hdr.update(dict(zip(hdr_rec.dtype.names, hdr_rec[0]))) + # Convert the first record of `header_rec` into a dictionnary + self.header = dict(zip(header_rec.dtype.names, header_rec[0])) # Check endianness - hdr[Field.ENDIAN] = native_code - if hdr['hdr_size'] != cls.HEADER_SIZE: - hdr[Field.ENDIAN] = swapped_code - hdr = dict(zip(hdr_rec.dtype.names, hdr_rec[0].newbyteorder())) # Swap byte order - if hdr['hdr_size'] != cls.HEADER_SIZE: - raise HeaderError('Invalid hdr_size: {0} instead of {1}'.format(hdr['hdr_size'], cls.HEADER_SIZE)) + self.endianness = native_code + if self.header['hdr_size'] != TrkFile.HEADER_SIZE: + self.endianness = swapped_code + + # Swap byte order + self.header = dict(zip(header_rec.dtype.names, header_rec[0].newbyteorder())) + if self.header['hdr_size'] != TrkFile.HEADER_SIZE: + raise HeaderError('Invalid hdr_size: {0} instead of {1}'.format(self.header['hdr_size'], TrkFile.HEADER_SIZE)) # By default, the voxel order is LPS. # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates - if hdr[Field.VOXEL_ORDER] == "": - warnings.warn("Voxel order is not specified, will assume 'LPS' since it is Trackvis software's default.") - hdr[Field.VOXEL_ORDER] = "LPS" + if self.header[Field.VOXEL_ORDER] == b"": + warnings.warn(("Voxel order is not specified, will assume" + " 'LPS' since it is Trackvis software's" + " default."), HeaderWarning) + self.header[Field.VOXEL_ORDER] = b"LPS" # Keep the file position where the data begin. - hdr['pos_data'] = f.tell() - - # If 'count' field is 0, i.e. not provided, we have to loop until the EOF. 
- if hdr[Field.NB_STREAMLINES] == 0: - del hdr[Field.NB_STREAMLINES] + self.offset_data = f.tell() - points = lambda: (x[0] for x in TrkFile._read_data(hdr, fileobj)) - scalars = lambda: (x[1] for x in TrkFile._read_data(hdr, fileobj)) - properties = lambda: (x[2] for x in TrkFile._read_data(hdr, fileobj)) - data = lambda: TrkFile._read_data(hdr, fileobj) + def __iter__(self): + i4_dtype = np.dtype(self.endianness + "i4") + f4_dtype = np.dtype(self.endianness + "f4") - if lazy_load: - count = lambda: TrkFile._count(hdr, fileobj) - if Field.NB_STREAMLINES in hdr: - count = hdr[Field.NB_STREAMLINES] - - streamlines = LazyStreamlines(points, scalars, properties, data=data, count=count) - else: - streamlines = Streamlines(*zip(*data())) - - # Set available header's information - streamlines.header.update(hdr) - return streamlines - - @staticmethod - def _read_data(hdr, fileobj): - ''' Read streamlines' data from a file-like object using a TRK's header. ''' - i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4") - f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4") - - with Opener(fileobj) as f: + with Opener(self.fileobj) as f: start_position = f.tell() - nb_pts_and_scalars = 3 + int(hdr[Field.NB_SCALARS_PER_POINT]) - pts_and_scalars_size = nb_pts_and_scalars * f4_dtype.itemsize + nb_pts_and_scalars = int(3 + self.header[Field.NB_SCALARS_PER_POINT]) + pts_and_scalars_size = int(nb_pts_and_scalars * f4_dtype.itemsize) slice_pts_and_scalars = lambda data: (data, []) - if hdr[Field.NB_SCALARS_PER_POINT] > 0: - # This is faster than np.split + if self.header[Field.NB_SCALARS_PER_POINT] > 0: + # This is faster than `np.split` slice_pts_and_scalars = lambda data: (data[:, :3], data[:, 3:]) # Using np.fromfile would be faster, but does not support StringIO @@ -211,18 +159,21 @@ def _read_data(hdr, fileobj): dtype=f4_dtype, buffer=f.read(nb_pts * pts_and_scalars_size))) - properties_size = int(hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) * f4_dtype.itemsize + properties_size = int(self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize) read_properties = lambda: [] - if hdr[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: + if self.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: read_properties = lambda: np.fromstring(f.read(properties_size), dtype=f4_dtype, - count=hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) + count=self.header[Field.NB_PROPERTIES_PER_STREAMLINE]) # Set the file position at the beginning of the data. - f.seek(hdr['pos_data'], os.SEEK_SET) + f.seek(self.offset_data, os.SEEK_SET) + + # If 'count' field is 0, i.e. not provided, we have to loop until the EOF. + nb_streamlines = self.header[Field.NB_STREAMLINES] + if nb_streamlines == 0: + nb_streamlines = np.inf - #for i in xrange(hdr[Field.NB_STREAMLINES]): - nb_streamlines = hdr.get(Field.NB_STREAMLINES, np.inf) i = 0 while i < nb_streamlines: nb_pts_str = f.read(i4_dtype.itemsize) @@ -239,163 +190,263 @@ def _read_data(hdr, fileobj): properties = read_properties() # TRK's streamlines are in 'voxelmm' space, we send them to voxel space. - pts = pts / hdr[Field.VOXEL_SIZES] + pts = pts / self.header[Field.VOXEL_SIZES] + # TrackVis considers coordinate (0,0,0) to be the corner of the + # voxel whereas streamlines returned assume (0,0,0) to be the + # center of the voxel. Thus, streamlines are shifted of half + #a voxel. + pts -= np.array(self.header[Field.VOXEL_SIZES])/2. yield pts, scalars, properties i += 1 # In case the 'count' field was not provided. 
- hdr[Field.NB_STREAMLINES] = i + self.header[Field.NB_STREAMLINES] = i - # Set the file position where it was. + # Set the file position where it was (in case it was already open). f.seek(start_position, os.SEEK_CUR) - @staticmethod - def _count(hdr, fileobj): - ''' Count streamlines from a file-like object using a TRK's header. ''' - nb_streamlines = 0 - with Opener(fileobj) as f: - start_position = f.tell() +class TrkWriter(object): + @classmethod + def create_empty_header(cls): + ''' Return an empty compliant TRK header. ''' + header = np.zeros(1, dtype=header_2_dtype) - i4_dtype = np.dtype(hdr[Field.ENDIAN] + "i4") - f4_dtype = np.dtype(hdr[Field.ENDIAN] + "f4") + #Default values + header[Field.MAGIC_NUMBER] = TrkFile.MAGIC_NUMBER + header[Field.VOXEL_SIZES] = (1, 1, 1) + header[Field.DIMENSIONS] = (1, 1, 1) + header[Field.VOXEL_TO_WORLD] = np.eye(4) + header['version'] = 2 + header['hdr_size'] = TrkFile.HEADER_SIZE - pts_and_scalars_size = (3 + hdr[Field.NB_SCALARS_PER_POINT]) * f4_dtype.itemsize - properties_size = hdr[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize + return header - # Set the file position at the beginning of the data. - f.seek(hdr['pos_data'], os.SEEK_SET) + def __init__(self, fileobj, header): + self.header = self.create_empty_header() - # Count the actual number of streamlines. - while True: - # Read number of points of the streamline - buf = f.read(i4_dtype.itemsize) + # Override hdr's fields by those contain in `header`. + for k, v in header.extra.items(): + if k in header_2_dtype.fields.keys(): + self.header[k] = v - if len(buf) == 0: - break # EOF + self.header[Field.NB_STREAMLINES] = 0 + if header.nb_streamlines is not None: + self.header[Field.NB_STREAMLINES] = header.nb_streamlines - nb_pts = struct.unpack(i4_dtype.str[:-1], buf)[0] - bytes_to_skip = nb_pts * pts_and_scalars_size - bytes_to_skip += properties_size + self.header[Field.NB_SCALARS_PER_POINT] = header.nb_scalars_per_point + self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = header.nb_properties_per_streamline + self.header[Field.VOXEL_SIZES] = header.voxel_sizes + self.header[Field.VOXEL_TO_WORLD] = header.voxel_to_world + self.header[Field.VOXEL_ORDER] = header.voxel_order - # Seek to the next streamline in the file. - f.seek(bytes_to_skip, os.SEEK_CUR) + # Keep counts for correcting incoherent fields or warn. + self.nb_streamlines = 0 + self.nb_points = 0 + self.nb_scalars = 0 + self.nb_properties = 0 + + # Write header + self.file = Opener(fileobj, mode="wb") + # Keep track of the beginning of the header. + self.beginning = self.file.tell() + self.file.write(self.header[0].tostring()) + + def write(self, streamlines): + i4_dtype = np.dtype("i4") + f4_dtype = np.dtype("f4") + + for points, scalars, properties in streamlines: + if len(scalars) > 0 and len(scalars) != len(points): + raise DataError("Missing scalars for some points!") + + points = np.array(points, dtype=f4_dtype) + scalars = np.array(scalars, dtype=f4_dtype).reshape((len(points), -1)) + properties = np.array(properties, dtype=f4_dtype) + + # TRK's streamlines need to be in 'voxelmm' space + points = points * self.header[Field.VOXEL_SIZES] + # TrackVis considers coordinate (0,0,0) to be the corner of the + # voxel whereas streamlines passed in parameters assume (0,0,0) + # to be the center of the voxel. Thus, streamlines are shifted of + # half a voxel. + points += np.array(self.header[Field.VOXEL_SIZES])/2. 
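The statement just above implements the save half of the coordinate convention described in the class docstring: points are rescaled from nibabel's voxel space, where (0,0,0) sits at the voxel center, into TrackVis's 'voxelmm' space, where (0,0,0) sits at a voxel corner. A minimal numpy sketch of the intended round trip, with made-up voxel sizes and points; note that for save and load to be exact inverses, the half-voxel shift has to be undone in the same units it was applied, i.e. subtracted in millimetres before rescaling (or, equivalently, 0.5 voxel after rescaling):

    import numpy as np

    voxel_sizes = np.array([2., 2., 2.], dtype=np.float32)
    pts_vox = np.array([[0., 0., 0.], [1., 2., 3.]], dtype=np.float32)

    # Save path: voxel space -> voxelmm, origin moves to the voxel corner.
    pts_trk = pts_vox * voxel_sizes + voxel_sizes / 2.

    # Load path: remove the shift while still in mm, then rescale.
    pts_vox_again = (pts_trk - voxel_sizes / 2.) / voxel_sizes

    assert np.allclose(pts_vox_again, pts_vox)
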
+ + data = struct.pack(i4_dtype.str[:-1], len(points)) + data += np.concatenate((points, scalars), axis=1).tostring() + data += properties.tostring() + self.file.write(data) + + self.nb_streamlines += 1 + self.nb_points += len(points) + self.nb_scalars += scalars.size + self.nb_properties += len(properties) + + # Either correct or warn if header and data are incoherent. + #TODO: add a warn option as a function parameter + nb_scalars_per_point = self.nb_scalars / self.nb_points + nb_properties_per_streamline = self.nb_properties / self.nb_streamlines + + # Check for errors + if nb_scalars_per_point != int(nb_scalars_per_point): + raise DataError("Nb. of scalars differs from one point to another!") + + if nb_properties_per_streamline != int(nb_properties_per_streamline): + raise DataError("Nb. of properties differs from one streamline to another!") + + self.header[Field.NB_STREAMLINES] = self.nb_streamlines + self.header[Field.NB_SCALARS_PER_POINT] = nb_scalars_per_point + self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = nb_properties_per_streamline + + # Overwrite header with updated one. + self.file.seek(self.beginning, os.SEEK_SET) + self.file.write(self.header[0].tostring()) - nb_streamlines += 1 - f.seek(start_position, os.SEEK_CUR) # Set the file position where it was. +class TrkFile(StreamlinesFile): + ''' Convenience class to encapsulate TRK file format. - return nb_streamlines + Note + ---- + TrackVis (so its file format: TRK) considers the streamline coordinate + (0,0,0) to be in the corner of the voxel whereas NiBabel's streamlines + internal representation (Voxel space) assume (0,0,0) to be in the + center of the voxel. + + Thus, streamlines are shifted of half a voxel on load and are shifted + back on save. + ''' + + # Contants + MAGIC_NUMBER = b"TRACK" + HEADER_SIZE = 1000 @classmethod - def create_empty_header(cls): - ''' Return an empty TRK compliant header. ''' - hdr = np.zeros(1, dtype=header_2_dtype) + def get_magic_number(cls): + ''' Return TRK's magic number. ''' + return cls.MAGIC_NUMBER - #Default values - hdr[Field.MAGIC_NUMBER] = cls.MAGIC_NUMBER - hdr[Field.VOXEL_SIZES] = (1, 1, 1) - hdr[Field.DIMENSIONS] = (1, 1, 1) - hdr[Field.VOXEL_TO_WORLD] = np.eye(4) - hdr['version'] = 2 - hdr['hdr_size'] = cls.HEADER_SIZE + @classmethod + def can_save_scalars(cls): + ''' Tells if the streamlines format supports saving scalars. ''' + return True - return hdr + @classmethod + def can_save_properties(cls): + ''' Tells if the streamlines format supports saving properties. ''' + return True @classmethod - def save(cls, streamlines, fileobj): - ''' Saves streamlines to a file-like object. + def is_correct_format(cls, fileobj): + ''' Check if the file is in TRK format. Parameters ---------- - streamlines : Streamlines object - Object containing streamlines' data and header information. - See 'nibabel.Streamlines'. - fileobj : string or file-like object If string, a filename; otherwise an open file-like object pointing to TRK file (and ready to read from the beginning - of the TRK header data) - - hdr : dict (optional) + of the TRK header data). - Notes - ----- - Streamlines are assumed to be in voxel space. + Returns + ------- + is_correct_format : boolean + Returns True if `fileobj` is in TRK format. 
'''
-        hdr = cls.create_empty_header()
-
-        #Override hdr's fields by those contain in `streamlines`'s header
-        for k, v in streamlines.header.items():
-            if k in header_2_dtype.fields.keys():
-                hdr[k] = v
+        with Opener(fileobj) as f:
+            magic_number = f.read(5)
+            f.seek(-5, os.SEEK_CUR)
+            return magic_number == cls.MAGIC_NUMBER
 
-        # Check which endianess to use to write data.
-        endianess = streamlines.header.get(Field.ENDIAN, native_code)
+        return False
 
-        if endianess == swapped_code:
-            hdr = hdr.newbyteorder()
+    @staticmethod
+    def load(fileobj, ref=None, lazy_load=False):
+        ''' Loads streamlines from a file-like object.
 
-        i4_dtype = np.dtype(endianess + "i4")
-        f4_dtype = np.dtype(endianess + "f4")
+        Parameters
+        ----------
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            pointing to TRK file (and ready to read from the beginning
+            of the TRK header).
 
-        # Keep counts for correcting incoherent fields or warn.
-        nb_streamlines = 0
-        nb_points = 0
-        nb_scalars = 0
-        nb_properties = 0
+        ref : filename | `Nifti1Image` object | 2D array (4,4) | None
+            Reference space where streamlines live in `fileobj`.
 
-        # Write header + data of streamlines
-        with Opener(fileobj, mode="wb") as f:
-            pos = f.tell()
-            # Write header
-            f.write(hdr[0].tostring())
+        lazy_load : boolean (optional)
+            Load streamlines in a lazy manner, i.e. they will not be kept
+            in memory.
 
-            for points, scalars, properties in streamlines:
-                if len(scalars) > 0 and len(scalars) != len(points):
-                    raise DataError("Missing scalars for some points!")
+        Returns
+        -------
+        streamlines : Streamlines object
+            Returns an object containing streamlines' data and header
+            information. See `nibabel.Streamlines`.
 
-                points = np.array(points, dtype=f4_dtype)
-                scalars = np.array(scalars, dtype=f4_dtype).reshape((len(points), -1))
-                properties = np.array(properties, dtype=f4_dtype)
+        Notes
+        -----
+        Streamlines are assumed to be in voxel space where coordinate (0,0,0)
+        refers to the center of the voxel.
+        '''
+        trk_reader = TrkReader(fileobj)
+
+        # Check if reference space matches one from TRK's header.
+        affine = trk_reader.header[Field.VOXEL_TO_WORLD]
+        if ref is not None:
+            affine = get_affine_from_reference(ref)
+            if not np.allclose(affine, trk_reader.header[Field.VOXEL_TO_WORLD]):
+                raise ValueError("Reference space provided does not match the"
+                                 " one from the TRK file header. 
Use `ref=None`" + " to use one contained in the TRK file") + + #points = lambda: (x[0] for x in trk_reader) + #scalars = lambda: (x[1] for x in trk_reader) + #properties = lambda: (x[2] for x in trk_reader) + data = lambda: iter(trk_reader) - # TRK's streamlines need to be in 'voxelmm' space - points = points * hdr[Field.VOXEL_SIZES] + if lazy_load: + streamlines = LazyStreamlines.create_from_data(data) - data = struct.pack(i4_dtype.str[:-1], len(points)) - data += np.concatenate((points, scalars), axis=1).tostring() - data += properties.tostring() - f.write(data) + # Overwrite scalars and properties if there is none + if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: + streamlines.scalars = lambda: [] + if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: + streamlines.properties = lambda: [] + else: + streamlines = Streamlines(*zip(*data())) - nb_streamlines += 1 - nb_points += len(points) - nb_scalars += scalars.size - nb_properties += len(properties) + # Overwrite scalars and properties if there is none + if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: + streamlines.scalars = [] + if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: + streamlines.properties = [] - # Either correct or warn if header and data are incoherent. - #TODO: add a warn option as a function parameter - nb_scalars_per_point = nb_scalars / nb_points - nb_properties_per_streamline = nb_properties / nb_streamlines + # Set available common information about streamlines in the header + streamlines.header.voxel_to_world = affine - # Check for errors - if nb_scalars_per_point != int(nb_scalars_per_point): - raise DataError("Nb. of scalars differs from one point to another!") + # If 'count' field is 0, i.e. not provided, we don't set `nb_streamlines` + if trk_reader.header[Field.NB_STREAMLINES] > 0: + streamlines.header.nb_streamlines = trk_reader.header[Field.NB_STREAMLINES] - if nb_properties_per_streamline != int(nb_properties_per_streamline): - raise DataError("Nb. of properties differs from one streamline to another!") + # Keep extra information about TRK format + streamlines.header.extra = trk_reader.header - hdr[Field.NB_STREAMLINES] = nb_streamlines - hdr[Field.NB_SCALARS_PER_POINT] = nb_scalars_per_point - hdr[Field.NB_PROPERTIES_PER_STREAMLINE] = nb_properties_per_streamline + ## Perform some integrity checks + #if trk_reader.header[Field.VOXEL_ORDER] != streamlines.header.voxel_order: + # raise HeaderError("'voxel_order' does not match the affine.") + #if streamlines.header.voxel_sizes != trk_reader.header[Field.VOXEL_SIZES]: + # raise HeaderError("'voxel_sizes' does not match the affine.") + #if streamlines.header.nb_scalars_per_point != trk_reader.header[Field.NB_SCALARS_PER_POINT]: + # raise HeaderError("'nb_scalars_per_point' does not match.") + #if streamlines.header.nb_properties_per_streamline != trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: + # raise HeaderError("'nb_properties_per_streamline' does not match.") - f.seek(pos, os.SEEK_SET) - f.write(hdr[0].tostring()) # Overwrite header with updated one. + return streamlines @staticmethod - def pretty_print(streamlines): - ''' Gets a formatted string contaning header's information - relevant to the TRK format. + def save(streamlines, fileobj, ref=None): + ''' Saves streamlines to a file-like object. Parameters ---------- @@ -403,12 +454,43 @@ def pretty_print(streamlines): Object containing streamlines' data and header information. See 'nibabel.Streamlines'. 
+ fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to TRK file (and ready to read from the beginning + of the TRK header data). + + ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) + Reference space where streamlines will live in `fileobj`. + + Notes + ----- + Streamlines are assumed to be in voxel space where coordinate (0,0,0) + refers to the center of the voxel. + ''' + if ref is not None: + streamlines.header.voxel_to_world = get_affine_from_reference(ref) + + trk_writer = TrkWriter(fileobj, streamlines.header) + trk_writer.write(streamlines) + + @staticmethod + def pretty_print(fileobj): + ''' Gets a formatted string of the header of a TRK file. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to TRK file (and ready to read from the beginning + of the header). + Returns ------- info : string - Header's information relevant to the TRK format. + Header information relevant to the TRK format. ''' - hdr = streamlines.header + trk_reader = TrkReader(fileobj) + hdr = trk_reader.header info = "" info += "MAGIC NUMBER: {0}".format(hdr[Field.MAGIC_NUMBER]) @@ -421,7 +503,6 @@ def pretty_print(streamlines): info += "nb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) info += "property_name:\n {0}".format("\n".join(hdr['property_name'])) info += "vox_to_world: {0}".format(hdr[Field.VOXEL_TO_WORLD]) - #info += "world_order: {0}".format(hdr[Field.WORLD_ORDER]) info += "voxel_order: {0}".format(hdr[Field.VOXEL_ORDER]) info += "image_orientation_patient: {0}".format(hdr['image_orientation_patient']) info += "pad1: {0}".format(hdr['pad1']) diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index 0afc7bccda..3b5f3ecf56 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -1,151 +1,38 @@ -import os +import numpy as np +import nibabel +import itertools -from ..externals.six import string_types +from nibabel.spatialimages import SpatialImage -from nibabel.streamlines import header -from nibabel.streamlines.base_format import LazyStreamlines -from nibabel.streamlines.trk import TrkFile -#from nibabel.streamlines.tck import TckFile -#from nibabel.streamlines.vtk import VtkFile +def get_affine_from_reference(ref): + """ Returns the affine defining the reference space. -# List of all supported formats -FORMATS = {".trk": TrkFile, - #".tck": TckFile, - #".vtk": VtkFile, - } - - -def is_supported(fileobj): - ''' Checks if the file-like object if supported by NiBabel. - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object pointing - to a streamlines file (and ready to read from the beginning of the - header) - - Returns - ------- - is_supported : boolean - ''' - return detect_format(fileobj) is not None - - -def detect_format(fileobj): - ''' Returns the StreamlinesFile object guessed from the file-like object. - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object pointing - to a streamlines file (and ready to read from the beginning of the - header) + Parameter + --------- + ref : filename | `Nifti1Image` object | 2D array (4,4) + Reference space where streamlines live in `fileobj`. Returns ------- - streamlines_file : StreamlinesFile object - Object that can be used to manage a streamlines file. - See 'nibabel.streamlines.StreamlinesFile'. 
- ''' - for format in FORMATS.values(): - try: - if format.is_correct_format(fileobj): - return format - - except IOError: - pass - - if isinstance(fileobj, string_types): - _, ext = os.path.splitext(fileobj) - return FORMATS.get(ext, None) - - return None - - -def load(fileobj, ref, lazy_load=False): - ''' Loads streamlines from a file-like object in voxel space. - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to a streamlines file (and ready to read from the beginning - of the streamlines file's header) - - ref : filename | `Nifti1Image` object | 2D array (3,3) | 2D array (4,4) | None - Reference space where streamlines have been created. - - lazy_load : boolean (optional) - Load streamlines in a lazy manner i.e. they will not be kept - in memory. - - Returns - ------- - obj : instance of ``Streamlines`` - Returns an instance of a ``Streamlines`` class containing data and metadata - of streamlines loaded from ``fileobj``. - ''' - streamlines_file = detect_format(fileobj) - - if streamlines_file is None: - raise TypeError("Unknown format for 'fileobj': {0}!".format(fileobj)) - - hdr = {} - if ref is not None: - hdr = header.create_header_from_reference(ref) - - return streamlines_file.load(fileobj, hdr=hdr, lazy_load=lazy_load) - - -def save(streamlines, filename, ref=None): - ''' Saves a ``Streamlines`` object to a file - - Parameters - ---------- - streamlines : Streamlines object - Streamlines to be saved (metadata is obtained with the function ``get_header`` of ``streamlines``). - - filename : string - Name of the file where the streamlines will be saved. The format will be guessed from ``filename``. - - ref : filename | `Nifti1Image` object | 2D array (3,3) | 2D array (4,4) | None (optional) - Reference space the streamlines belong to. Default: get ref from `streamlines.header`. - ''' - streamlines_file = detect_format(filename) - - if streamlines_file is None: - raise TypeError("Unknown format for 'filename': {0}!".format(filename)) - - if ref is not None: - # Create a `LazyStreamlines` from `streamlines` but using the new reference image. - streamlines = LazyStreamlines(data=iter(streamlines)) - streamlines.header.update(streamlines.header) - streamlines.header.update(header.create_header_from_reference(ref)) - - streamlines_file.save(streamlines, filename) - - -def convert(in_fileobj, out_filename): - ''' Converts one streamlines format to another. + affine : 2D array (4,4) + """ + if type(ref) is np.ndarray: + if ref.shape != (4, 4): + raise ValueError("`ref` needs to be a numpy array with shape (4,4)!") - It does not change the space in which the streamlines are. + return ref + elif isinstance(ref, SpatialImage): + return ref.affine - Parameters - ---------- - in_fileobj : string or file-like object - If string, a filename; otherwise an open file-like object pointing - to a streamlines file (and ready to read from the beginning of the - header) + # Assume `ref` is the name of a neuroimaging file. + return nibabel.load(ref).affine - out_filename : string - Name of the file where the streamlines will be saved. The format will - be guessed from ``out_filename``. 
- ''' - streamlines = load(in_fileobj, lazy_load=True) - save(streamlines, out_filename) +def pop(iterable): + "Returns the next item from the iterable else None" + value = list(itertools.islice(iterable, 1)) + return value[0] if len(value) > 0 else None # TODO def change_space(streamline_file, new_point_space): From a0f6720897f906d6db9a20560990f226b3fc21f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 17 Apr 2015 00:53:54 -0400 Subject: [PATCH 007/135] Fixed import of OrderedDict and added support to apply transformation on streamlines. --- nibabel/streamlines/base_format.py | 85 +++++++++++++++++-- nibabel/streamlines/header.py | 40 +++++---- nibabel/streamlines/tests/test_base_format.py | 18 ++-- nibabel/streamlines/tests/test_header.py | 12 +-- nibabel/streamlines/tests/test_streamlines.py | 16 ++-- nibabel/streamlines/trk.py | 16 ++-- nibabel/streamlines/utils.py | 4 - 7 files changed, 134 insertions(+), 57 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index a94f5728ad..e5068c69f8 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -1,3 +1,4 @@ +import numpy as np from warnings import warn from nibabel.externals.six.moves import zip_longest @@ -111,14 +112,56 @@ def __getitem__(self, idx): def __len__(self): return len(self.points) - def to_world_space(self, as_generator=False): - affine = self.header.voxel_to_world - new_points = (apply_affine(affine, pts) for pts in self.points) + def copy(self): + """ Returns a copy of this `Streamlines` object. """ + streamlines = Streamlines(self.points, self.scalars, self.properties) + streamlines._header = self.header.copy() + return streamlines - if not as_generator: - return list(new_points) + def transform(self, affine, lazy=False): + """ Applies an affine transformation on the points of each streamline. - return new_points + Parameters + ---------- + affine : 2D array (4,4) + Transformation that will be applied on each streamline. + lazy : bool (optional) + If true output will be a generator of arrays instead of a list. + + Returns + ------- + streamlines + If `lazy` is true, a `LazyStreamlines` object is returned, + otherwise a `Streamlines` object is returned. In both case, + streamlines are in a space defined by `affine`. + """ + points = lambda: (apply_affine(affine, pts) for pts in self.points) + if not lazy: + points = list(points()) + + streamlines = self.copy() + streamlines.points = points + streamlines.header.to_world_space = np.dot(streamlines.header.to_world_space, + np.linalg.inv(affine)) + + return streamlines + + def to_world_space(self, lazy=False): + """ Sends the streamlines back into world space. + + Parameters + ---------- + lazy : bool (optional) + If true output will be a generator of arrays instead of a list. + + Returns + ------- + streamlines + If `lazy` is true, a `LazyStreamlines` object is returned, + otherwise a `Streamlines` object is returned. In both case, + streamlines are in world space. + """ + return self.transform(self.header.to_world_space, lazy) class LazyStreamlines(Streamlines): @@ -255,8 +298,36 @@ def __len__(self): return self.header.nb_streamlines + def copy(self): + """ Returns a copy of this `LazyStreamlines` object. """ + streamlines = LazyStreamlines(self._points, self._scalars, self._properties) + streamlines._header = self.header.copy() + return streamlines + + def transform(self, affine): + """ Applies an affine transformation on the points of each streamline. 
+ + Parameters + ---------- + affine : 2D array (4,4) + Transformation that will be applied on each streamline. + + Returns + ------- + streamlines : `LazyStreamlines` object + Streamlines living in a space defined by `affine`. + """ + return super(LazyStreamlines, self).transform(affine, lazy=True) + def to_world_space(self): - return super(LazyStreamlines, self).to_world_space(as_generator=True) + """ Sends the streamlines back into world space. + + Returns + ------- + streamlines : `LazyStreamlines` object + Streamlines living in world space. + """ + return super(LazyStreamlines, self).to_world_space(lazy=True) class StreamlinesFile: diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py index 46e4af3a94..7a7ec63b3c 100644 --- a/nibabel/streamlines/header.py +++ b/nibabel/streamlines/header.py @@ -1,6 +1,7 @@ +import copy import numpy as np from nibabel.orientations import aff2axcodes -from collections import OrderedDict +from nibabel.externals import OrderedDict class Field: @@ -18,7 +19,7 @@ class Field: DIMENSIONS = "dimensions" MAGIC_NUMBER = "magic_number" ORIGIN = "origin" - VOXEL_TO_WORLD = "voxel_to_world" + to_world_space = "to_world_space" VOXEL_ORDER = "voxel_order" ENDIAN = "endian" @@ -28,33 +29,33 @@ def __init__(self): self._nb_streamlines = None self._nb_scalars_per_point = None self._nb_properties_per_streamline = None - self._voxel_to_world = np.eye(4) + self._to_world_space = np.eye(4) self.extra = OrderedDict() @property - def voxel_to_world(self): - return self._voxel_to_world + def to_world_space(self): + return self._to_world_space - @voxel_to_world.setter - def voxel_to_world(self, value): - self._voxel_to_world = np.array(value, dtype=np.float32) + @to_world_space.setter + def to_world_space(self, value): + self._to_world_space = np.array(value, dtype=np.float32) @property def voxel_sizes(self): - """ Get voxel sizes from voxel_to_world. """ - return np.sqrt(np.sum(self.voxel_to_world[:3, :3]**2, axis=0)) + """ Get voxel sizes from to_world_space. """ + return np.sqrt(np.sum(self.to_world_space[:3, :3]**2, axis=0)) @voxel_sizes.setter def voxel_sizes(self, value): scaling = np.r_[np.array(value), [1]] old_scaling = np.r_[np.array(self.voxel_sizes), [1]] # Remove old scaling and apply new one - self.voxel_to_world = np.dot(np.diag(scaling/old_scaling), self.voxel_to_world) + self.to_world_space = np.dot(np.diag(scaling/old_scaling), self.to_world_space) @property def voxel_order(self): - """ Get voxel order from voxel_to_world. """ - return "".join(aff2axcodes(self.voxel_to_world)) + """ Get voxel order from to_world_space. 
""" + return "".join(aff2axcodes(self.to_world_space)) @property def nb_streamlines(self): @@ -88,8 +89,17 @@ def extra(self): def extra(self, value): self._extra = OrderedDict(value) + def copy(self): + header = StreamlinesHeader() + header._nb_streamlines = self.nb_streamlines + header.nb_scalars_per_point = self.nb_scalars_per_point + header.nb_properties_per_streamline = self.nb_properties_per_streamline + header.to_world_space = self.to_world_space.copy() + header.extra = copy.deepcopy(self.extra) + return header + def __eq__(self, other): - return (np.allclose(self.voxel_to_world, other.voxel_to_world) and + return (np.allclose(self.to_world_space, other.to_world_space) and self.nb_streamlines == other.nb_streamlines and self.nb_scalars_per_point == other.nb_scalars_per_point and self.nb_properties_per_streamline == other.nb_properties_per_streamline and @@ -100,7 +110,7 @@ def __repr__(self): txt += "nb_streamlines: " + repr(self.nb_streamlines) + '\n' txt += "nb_scalars_per_point: " + repr(self.nb_scalars_per_point) + '\n' txt += "nb_properties_per_streamline: " + repr(self.nb_properties_per_streamline) + '\n' - txt += "voxel_to_world: " + repr(self.voxel_to_world) + '\n' + txt += "to_world_space: " + repr(self.to_world_space) + '\n' txt += "voxel_sizes: " + repr(self.voxel_sizes) + '\n' txt += "Extra fields: {\n" diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index e132ccf576..f2396da411 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -117,8 +117,8 @@ def test_to_world_space(self): # World space is (RAS+) with voxel size of 2x3x4mm. streamlines.header.voxel_sizes = (2, 3, 4) - new_points = streamlines.to_world_space() - for new_pts, pts in zip(new_points, self.points): + new_streamlines = streamlines.to_world_space() + for new_pts, pts in zip(new_streamlines.points, self.points): for dim, size in enumerate(streamlines.header.voxel_sizes): assert_array_almost_equal(new_pts[:, dim], size*pts[:, dim]) @@ -129,7 +129,7 @@ def test_header(self): assert_equal(streamlines.header.nb_scalars_per_point, 0) assert_equal(streamlines.header.nb_properties_per_streamline, 0) assert_array_equal(streamlines.header.voxel_sizes, (1, 1, 1)) - assert_array_equal(streamlines.header.voxel_to_world, np.eye(4)) + assert_array_equal(streamlines.header.to_world_space, np.eye(4)) assert_equal(streamlines.header.extra, {}) streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) @@ -138,15 +138,15 @@ def test_header(self): assert_equal(streamlines.header.nb_scalars_per_point, self.colors[0].shape[1]) assert_equal(streamlines.header.nb_properties_per_streamline, self.mean_curvature_torsion[0].shape[0]) - # Modifying voxel_sizes should be reflected in voxel_to_world + # Modifying voxel_sizes should be reflected in to_world_space streamlines.header.voxel_sizes = (2, 3, 4) assert_array_equal(streamlines.header.voxel_sizes, (2, 3, 4)) - assert_array_equal(np.diag(streamlines.header.voxel_to_world), (2, 3, 4, 1)) + assert_array_equal(np.diag(streamlines.header.to_world_space), (2, 3, 4, 1)) - # Modifying scaling of voxel_to_world should be reflected in voxel_sizes - streamlines.header.voxel_to_world = np.diag([4, 3, 2, 1]) + # Modifying scaling of to_world_space should be reflected in voxel_sizes + streamlines.header.to_world_space = np.diag([4, 3, 2, 1]) assert_array_equal(streamlines.header.voxel_sizes, (4, 3, 2)) - 
assert_array_equal(streamlines.header.voxel_to_world, np.diag([4, 3, 2, 1])) + assert_array_equal(streamlines.header.to_world_space, np.diag([4, 3, 2, 1])) # Test that we can run __repr__ without error. repr(streamlines.header) @@ -313,7 +313,7 @@ def test_lazy_streamlines_header(self): assert_equal(streamlines.header.nb_scalars_per_point, 0) assert_equal(streamlines.header.nb_properties_per_streamline, 0) assert_array_equal(streamlines.header.voxel_sizes, (1, 1, 1)) - assert_array_equal(streamlines.header.voxel_to_world, np.eye(4)) + assert_array_equal(streamlines.header.to_world_space, np.eye(4)) assert_equal(streamlines.header.extra, {}) points = lambda: (x for x in self.points) diff --git a/nibabel/streamlines/tests/test_header.py b/nibabel/streamlines/tests/test_header.py index 99184bb2ba..a36a257818 100644 --- a/nibabel/streamlines/tests/test_header.py +++ b/nibabel/streamlines/tests/test_header.py @@ -12,7 +12,7 @@ def test_streamlines_header(): assert_true(header.nb_scalars_per_point is None) assert_true(header.nb_properties_per_streamline is None) assert_array_equal(header.voxel_sizes, (1, 1, 1)) - assert_array_equal(header.voxel_to_world, np.eye(4)) + assert_array_equal(header.to_world_space, np.eye(4)) assert_equal(header.extra, {}) # Modify simple attributes @@ -23,15 +23,15 @@ def test_streamlines_header(): assert_equal(header.nb_scalars_per_point, 2) assert_equal(header.nb_properties_per_streamline, 3) - # Modifying voxel_sizes should be reflected in voxel_to_world + # Modifying voxel_sizes should be reflected in to_world_space header.voxel_sizes = (2, 3, 4) assert_array_equal(header.voxel_sizes, (2, 3, 4)) - assert_array_equal(np.diag(header.voxel_to_world), (2, 3, 4, 1)) + assert_array_equal(np.diag(header.to_world_space), (2, 3, 4, 1)) - # Modifying scaling of voxel_to_world should be reflected in voxel_sizes - header.voxel_to_world = np.diag([4, 3, 2, 1]) + # Modifying scaling of to_world_space should be reflected in voxel_sizes + header.to_world_space = np.diag([4, 3, 2, 1]) assert_array_equal(header.voxel_sizes, (4, 3, 2)) - assert_array_equal(header.voxel_to_world, np.diag([4, 3, 2, 1])) + assert_array_equal(header.to_world_space, np.diag([4, 3, 2, 1])) # Test that we can run __repr__ without error. 
repr(header) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index f137559631..c293ece3b1 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -140,12 +140,12 @@ def setUp(self): self.nb_streamlines = len(self.points) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) - self.voxel_to_world = np.eye(4) + self.to_world_space = np.eye(4) def test_load_empty_file(self): for empty_filename in self.empty_filenames: streamlines = nib.streamlines.load(empty_filename, - ref=self.voxel_to_world, + ref=self.to_world_space, lazy_load=False) assert_true(type(streamlines), Streamlines) check_streamlines(streamlines, 0, [], [], []) @@ -153,7 +153,7 @@ def test_load_empty_file(self): def test_load_simple_file(self): for simple_filename in self.simple_filenames: streamlines = nib.streamlines.load(simple_filename, - ref=self.voxel_to_world, + ref=self.to_world_space, lazy_load=False) assert_true(type(streamlines), Streamlines) check_streamlines(streamlines, self.nb_streamlines, @@ -161,7 +161,7 @@ def test_load_simple_file(self): # Test lazy_load streamlines = nib.streamlines.load(simple_filename, - ref=self.voxel_to_world, + ref=self.to_world_space, lazy_load=True) assert_true(type(streamlines), LazyStreamlines) check_streamlines(streamlines, self.nb_streamlines, @@ -180,7 +180,7 @@ def test_load_complex_file(self): properties = self.mean_curvature_torsion streamlines = nib.streamlines.load(complex_filename, - ref=self.voxel_to_world, + ref=self.to_world_space, lazy_load=False) assert_true(type(streamlines), Streamlines) check_streamlines(streamlines, self.nb_streamlines, @@ -188,7 +188,7 @@ def test_load_complex_file(self): # Test lazy_load streamlines = nib.streamlines.load(complex_filename, - ref=self.voxel_to_world, + ref=self.to_world_space, lazy_load=True) assert_true(type(streamlines), LazyStreamlines) check_streamlines(streamlines, self.nb_streamlines, @@ -199,7 +199,7 @@ def test_save_simple_file(self): for ext in nib.streamlines.FORMATS.keys(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: nib.streamlines.save(streamlines, f.name) - loaded_streamlines = nib.streamlines.load(f, ref=self.voxel_to_world, lazy_load=False) + loaded_streamlines = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) check_streamlines(loaded_streamlines, self.nb_streamlines, self.points, [], []) @@ -224,6 +224,6 @@ def test_save_complex_file(self): if cls.can_save_properties(): properties = self.mean_curvature_torsion - loaded_streamlines = nib.streamlines.load(f, ref=self.voxel_to_world, lazy_load=False) + loaded_streamlines = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) check_streamlines(loaded_streamlines, self.nb_streamlines, self.points, scalars, properties) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 0c9c982d89..2a9ea9d29d 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -58,7 +58,7 @@ ('scalar_name', 'S20', 10), (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), ('property_name', 'S20', 10), - (Field.VOXEL_TO_WORLD, 'f4', (4, 4)), # new field for version 2 + (Field.to_world_space, 'f4', (4, 4)), # new field for version 2 ('reserved', 'S444'), (Field.VOXEL_ORDER, 'S4'), ('pad2', 'S4'), @@ -217,7 +217,7 @@ def create_empty_header(cls): header[Field.MAGIC_NUMBER] = TrkFile.MAGIC_NUMBER header[Field.VOXEL_SIZES] = (1, 1, 1) 
header[Field.DIMENSIONS] = (1, 1, 1) - header[Field.VOXEL_TO_WORLD] = np.eye(4) + header[Field.to_world_space] = np.eye(4) header['version'] = 2 header['hdr_size'] = TrkFile.HEADER_SIZE @@ -238,7 +238,7 @@ def __init__(self, fileobj, header): self.header[Field.NB_SCALARS_PER_POINT] = header.nb_scalars_per_point self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = header.nb_properties_per_streamline self.header[Field.VOXEL_SIZES] = header.voxel_sizes - self.header[Field.VOXEL_TO_WORLD] = header.voxel_to_world + self.header[Field.to_world_space] = header.to_world_space self.header[Field.VOXEL_ORDER] = header.voxel_order # Keep counts for correcting incoherent fields or warn. @@ -392,10 +392,10 @@ def load(fileobj, ref=None, lazy_load=False): trk_reader = TrkReader(fileobj) # Check if reference space matches one from TRK's header. - affine = trk_reader.header[Field.VOXEL_TO_WORLD] + affine = trk_reader.header[Field.to_world_space] if ref is not None: affine = get_affine_from_reference(ref) - if not np.allclose(affine, trk_reader.header[Field.VOXEL_TO_WORLD]): + if not np.allclose(affine, trk_reader.header[Field.to_world_space]): raise ValueError("Reference space provided does not match the " " one from the TRK file header. Use `ref=None`" " to use one contained in the TRK file") @@ -423,7 +423,7 @@ def load(fileobj, ref=None, lazy_load=False): streamlines.properties = [] # Set available common information about streamlines in the header - streamlines.header.voxel_to_world = affine + streamlines.header.to_world_space = affine # If 'count' field is 0, i.e. not provided, we don't set `nb_streamlines` if trk_reader.header[Field.NB_STREAMLINES] > 0: @@ -468,7 +468,7 @@ def save(streamlines, fileobj, ref=None): refers to the center of the voxel. ''' if ref is not None: - streamlines.header.voxel_to_world = get_affine_from_reference(ref) + streamlines.header.to_world_space = get_affine_from_reference(ref) trk_writer = TrkWriter(fileobj, streamlines.header) trk_writer.write(streamlines) @@ -502,7 +502,7 @@ def pretty_print(fileobj): info += "scalar_name:\n {0}".format("\n".join(hdr['scalar_name'])) info += "nb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) info += "property_name:\n {0}".format("\n".join(hdr['property_name'])) - info += "vox_to_world: {0}".format(hdr[Field.VOXEL_TO_WORLD]) + info += "vox_to_world: {0}".format(hdr[Field.to_world_space]) info += "voxel_order: {0}".format(hdr[Field.VOXEL_ORDER]) info += "image_orientation_patient: {0}".format(hdr['image_orientation_patient']) info += "pad1: {0}".format(hdr['pad1']) diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index 3b5f3ecf56..7bbbe1ef8d 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -33,7 +33,3 @@ def pop(iterable): "Returns the next item from the iterable else None" value = list(itertools.islice(iterable, 1)) return value[0] if len(value) > 0 else None - -# TODO -def change_space(streamline_file, new_point_space): - pass From db1c18271db7e065237d8a0767e0f5bc514b6a26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 22 Sep 2015 19:44:21 -0400 Subject: [PATCH 008/135] Added a Streamline class used when indexing/iterating a Streamlines object --- nibabel/streamlines/base_format.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index e5068c69f8..2db3f4e4d6 100644 --- a/nibabel/streamlines/base_format.py +++ 
b/nibabel/streamlines/base_format.py @@ -24,6 +24,19 @@ class DataError(Exception): pass +class Streamline(object): + def __init__(self, points, scalars=None, properties=None): + self.points = points + self.scalars = scalars + self.properties = properties + + def __iter__(self): + return iter(self.points) + + def __len__(self): + return len(self.points) + + class Streamlines(object): ''' Class containing information about streamlines. @@ -92,7 +105,8 @@ def properties(self, value): self.header.nb_properties_per_streamline = len(self.properties[0]) def __iter__(self): - return zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) + for data in zip_longest(self.points, self.scalars, self.properties, fillvalue=[]): + yield Streamline(*data) def __getitem__(self, idx): pts = self.points[idx] @@ -107,7 +121,7 @@ def __getitem__(self, idx): if type(idx) is slice: return list(zip_longest(pts, scalars, properties, fillvalue=[])) - return pts, scalars, properties + return Streamline(pts, scalars, properties) def __len__(self): return len(self.points) From db96c2668fa1aa6e5cb6d2b882809f9310fcc659 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 28 Oct 2015 16:46:50 -0400 Subject: [PATCH 009/135] Fixed tests to use clear_and_catch_warnings --- nibabel/streamlines/base_format.py | 18 ++---------------- nibabel/streamlines/tests/test_base_format.py | 8 ++++---- nibabel/streamlines/tests/test_streamlines.py | 4 ++-- nibabel/streamlines/tests/test_trk.py | 6 +++--- nibabel/testing/__init__.py | 3 +++ 5 files changed, 14 insertions(+), 25 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 2db3f4e4d6..e5068c69f8 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -24,19 +24,6 @@ class DataError(Exception): pass -class Streamline(object): - def __init__(self, points, scalars=None, properties=None): - self.points = points - self.scalars = scalars - self.properties = properties - - def __iter__(self): - return iter(self.points) - - def __len__(self): - return len(self.points) - - class Streamlines(object): ''' Class containing information about streamlines. 
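For reference, the `Streamline` wrapper introduced just above (and backed out again in this patch before returning later in the series) simply bundles one streamline's data, so that iterating or indexing a `Streamlines` object hands back a single record rather than a (points, scalars, properties) tuple. A self-contained sketch of that access pattern; the class body matches the patch, while the toy arrays are invented for illustration:

    import numpy as np

    class Streamline(object):
        # A thin record type, as in the patch.
        def __init__(self, points, scalars=None, properties=None):
            self.points = points
            self.scalars = scalars
            self.properties = properties

        def __iter__(self):
            return iter(self.points)

        def __len__(self):
            return len(self.points)

    s = Streamline(points=np.zeros((10, 3), dtype=np.float32),
                   scalars=np.ones((10, 2), dtype=np.float32),        # e.g. per-point colors
                   properties=np.array([0.5, 1.2], dtype=np.float32)) # e.g. curvature, torsion

    assert len(s) == 10    # length counts points
    for point in s:        # iterating walks the points
        assert point.shape == (3,)
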
@@ -105,8 +92,7 @@ def properties(self, value): self.header.nb_properties_per_streamline = len(self.properties[0]) def __iter__(self): - for data in zip_longest(self.points, self.scalars, self.properties, fillvalue=[]): - yield Streamline(*data) + return zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) def __getitem__(self, idx): pts = self.points[idx] @@ -121,7 +107,7 @@ def __getitem__(self, idx): if type(idx) is slice: return list(zip_longest(pts, scalars, properties, fillvalue=[])) - return Streamline(pts, scalars, properties) + return pts, scalars, properties def __len__(self): return len(self.points) diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index f2396da411..c8a9465265 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -4,7 +4,7 @@ import warnings from nibabel.testing import assert_arrays_equal -from nibabel.testing import suppress_warnings, catch_warn_reset +from nibabel.testing import suppress_warnings, clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true from numpy.testing import assert_array_equal, assert_array_almost_equal from nibabel.externals.six.moves import zip @@ -267,7 +267,7 @@ def test_lazy_streamlines_len(self): scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - with catch_warn_reset(record=True, modules=[base_format]) as w: + with clear_and_catch_warnings(record=True, modules=[base_format]) as w: warnings.simplefilter("always") # Always trigger warnings. # Calling `len` will create new generators each time. @@ -286,7 +286,7 @@ def test_lazy_streamlines_len(self): assert_equal(len(streamlines), self.nb_streamlines) assert_equal(len(w), 2) - with catch_warn_reset(record=True, modules=[base_format]) as w: + with clear_and_catch_warnings(record=True, modules=[base_format]) as w: # Once we iterated through the streamlines, we know the length. streamlines = LazyStreamlines(points, scalars, properties) assert_true(streamlines.header.nb_streamlines is None) @@ -298,7 +298,7 @@ def test_lazy_streamlines_len(self): assert_equal(len(streamlines), len(self.points)) assert_equal(len(w), 0) - with catch_warn_reset(record=True, modules=[base_format]) as w: + with clear_and_catch_warnings(record=True, modules=[base_format]) as w: # It first checks if number of streamlines is in the header. 
streamlines = LazyStreamlines(points, scalars, properties) streamlines.header.nb_streamlines = 1234 diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index c293ece3b1..893331cb77 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -8,7 +8,7 @@ import nibabel as nib from nibabel.externals.six import BytesIO -from nibabel.testing import catch_warn_reset +from nibabel.testing import clear_and_catch_warnings from nibabel.testing import assert_arrays_equal from nose.tools import assert_equal, assert_raises, assert_true, assert_false @@ -207,7 +207,7 @@ def test_save_complex_file(self): streamlines = Streamlines(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) for ext, cls in nib.streamlines.FORMATS.items(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: - with catch_warn_reset(record=True, modules=[trk]) as w: + with clear_and_catch_warnings(record=True, modules=[trk]) as w: nib.streamlines.save(streamlines, f.name) # If streamlines format does not support saving scalars or diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 2d47f5c841..1a1aa2ce05 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -4,7 +4,7 @@ from nibabel.externals.six import BytesIO -from nibabel.testing import suppress_warnings, catch_warn_reset +from nibabel.testing import suppress_warnings, clear_and_catch_warnings from nibabel.testing import assert_arrays_equal, assert_streamlines_equal from nose.tools import assert_equal, assert_raises, assert_true @@ -103,7 +103,7 @@ def test_load_file_with_wrong_information(self): check_streamlines(streamlines, self.nb_streamlines, self.points, [], []) streamlines = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) - with catch_warn_reset(record=True, modules=[base_format]) as w: + with clear_and_catch_warnings(record=True, modules=[base_format]) as w: check_streamlines(streamlines, self.nb_streamlines, self.points, [], []) assert_equal(len(w), 1) assert_true(issubclass(w[0].category, UsageWarning)) @@ -111,7 +111,7 @@ def test_load_file_with_wrong_information(self): # Simulate a TRK file where `voxel_order` was not provided. 
voxel_order = np.zeros(1, dtype="|S3").tostring() new_trk_file = trk_file[:948] + voxel_order + trk_file[948+3:] - with catch_warn_reset(record=True, modules=[trk]) as w: + with clear_and_catch_warnings(record=True, modules=[trk]) as w: TrkFile.load(BytesIO(new_trk_file), ref=None) assert_equal(len(w), 1) assert_true(issubclass(w[0].category, HeaderWarning)) diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py index c6b5ebd66b..edf6002394 100644 --- a/nibabel/testing/__init__.py +++ b/nibabel/testing/__init__.py @@ -14,8 +14,10 @@ import sys import warnings from os.path import dirname, abspath, join as pjoin +from nibabel.externals.six.moves import zip_longest import numpy as np +from numpy.testing import assert_array_equal from numpy.testing.decorators import skipif # Allow failed import of nose if not now running tests @@ -73,6 +75,7 @@ def assert_re_in(regex, c, flags=0): raise AssertionError("Not a single entry matched %r in %r" % (regex, c)) + def get_fresh_mod(mod_name=__name__): # Get this module, with warning registry empty my_mod = sys.modules[mod_name] From 461d5e4e91f3b7344fcdb6a6b3729928becbabf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Thu, 29 Oct 2015 13:45:34 -0400 Subject: [PATCH 010/135] Added the CompactList data structure to keep points and scalars --- nibabel/benchmarks/bench_streamlines.py | 33 ++-- nibabel/streamlines/base_format.py | 165 +++++++++++++++++- nibabel/streamlines/tests/test_base_format.py | 37 ++-- nibabel/streamlines/tests/test_streamlines.py | 2 +- nibabel/streamlines/tests/test_trk.py | 52 +++--- nibabel/streamlines/trk.py | 62 ++++++- 6 files changed, 281 insertions(+), 70 deletions(-) diff --git a/nibabel/benchmarks/bench_streamlines.py b/nibabel/benchmarks/bench_streamlines.py index 3a2e3ab39d..4eba60be7d 100644 --- a/nibabel/benchmarks/bench_streamlines.py +++ b/nibabel/benchmarks/bench_streamlines.py @@ -40,13 +40,21 @@ def bench_load_trk(): repeat = 20 trk_file = BytesIO() - trk = list(zip(points, [None]*NB_STREAMLINES, [None]*NB_STREAMLINES)) - tv.write(trk_file, trk) + #trk = list(zip(points, [None]*NB_STREAMLINES, [None]*NB_STREAMLINES)) + #tv.write(trk_file, trk) + streamlines = Streamlines(points) + TrkFile.save(streamlines, trk_file) + + from pycallgraph import PyCallGraph + from pycallgraph.output import GraphvizOutput - mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, lazy_load=False)', repeat) - print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) + with PyCallGraph(output=GraphvizOutput()): + #nib.streamlines.load(trk_file, ref=None, lazy_load=False) - mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file)', repeat) + mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, ref=None, lazy_load=False)', repeat) + print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) + + mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file, points_space="voxel")', repeat) print("Old: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old)) print("Speedup of %2f" % (mtime_old/mtime_new)) @@ -54,13 +62,15 @@ def bench_load_trk(): scalars = [np.random.rand(NB_POINTS, 10).astype('float32') for i in range(NB_STREAMLINES)] trk_file = BytesIO() - trk = list(zip(points, scalars, [None]*NB_STREAMLINES)) - tv.write(trk_file, trk) + #trk = list(zip(points, scalars, [None]*NB_STREAMLINES)) + #tv.write(trk_file, trk) + streamlines = Streamlines(points, scalars) + 
TrkFile.save(streamlines, trk_file)
 
-    mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, lazy_load=False)', repeat)
+    mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, ref=None, lazy_load=False)', repeat)
     print("New: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_new))
 
-    mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file)', repeat)
+    mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file, points_space="voxel")', repeat)
     print("Old: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_old))
     print("Speedup of %2f" % (mtime_old/mtime_new))
 
@@ -91,8 +101,7 @@ def bench_save_trk():
     for pts, A in zip(points, streams):
         assert_array_equal(pts, A[0])
 
-    trk = nib.streamlines.load(trk_file_new, lazy_load=False)
-
+    trk = nib.streamlines.load(trk_file_new, ref=None, lazy_load=False)
     assert_arrays_equal(points, trk.points)
 
     # Points and scalars
@@ -117,7 +126,7 @@ def bench_save_trk():
         assert_array_equal(pts, A[0])
         assert_array_equal(scal, A[1])
 
-    trk = nib.streamlines.load(trk_file_new, lazy_load=False)
+    trk = nib.streamlines.load(trk_file_new, ref=None, lazy_load=False)
     assert_arrays_equal(points, trk.points)
     assert_arrays_equal(scalars, trk.scalars)
 
diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py
index e5068c69f8..8c6f3a9973 100644
--- a/nibabel/streamlines/base_format.py
+++ b/nibabel/streamlines/base_format.py
@@ -24,6 +24,128 @@ class DataError(Exception):
     pass
 
 
+class CompactList(object):
+    def __init__(self, preallocate=(0,), dtype=np.float32):
+        self.dtype = dtype
+        self.data = np.empty(preallocate, dtype=dtype)
+        self.offsets = []
+        self.lengths = []
+
+    @classmethod
+    def from_list(cls, elements):
+        """ Fast way to create a `CompactList` from a list of 2D ndarrays.
+        Parameters
+        ----------
+        elements : list
+            List of 2D ndarrays with the same shape except for the first dimension.
+        """
+        if len(elements) == 0:
+            return cls()
+
+        first_element = np.asarray(elements[0])
+        s = cls(preallocate=(0,) + first_element.shape[1:], dtype=first_element.dtype)
+        s.extend(elements)
+        return s
+
+    def append(self, element):
+        """
+        Parameters
+        ----------
+        element : 2D ndarray with `element.shape[1:] == self.data.shape[1:]`
+            Element to add.
+        Note
+        ----
+        If you need to add multiple elements, consider
+        `CompactList.from_list` or `CompactList.extend`.
+        """
+        self.offsets.append(len(self.data))
+        self.lengths.append(len(element))
+        self.data = np.append(self.data, element, axis=0)
+
+    def extend(self, elements):
+        if isinstance(elements, CompactList):
+            self.data = np.concatenate([self.data, elements.data], axis=0)
+            offset = self.offsets[-1] + self.lengths[-1] if len(self) > 0 else 0
+            self.lengths.extend(elements.lengths)
+            self.offsets.extend(np.cumsum([offset] + elements.lengths).tolist()[:-1])
+        else:
+            self.data = np.concatenate([self.data] + list(elements), axis=0)
+            offset = self.offsets[-1] + self.lengths[-1] if len(self) > 0 else 0
+            lengths = map(len, elements)
+            self.lengths.extend(lengths)
+            self.offsets.extend(np.cumsum([offset] + lengths).tolist()[:-1])
+
+    def __getitem__(self, idx):
+        """ Gets element(s) through indexing.
+        Parameters
+        ----------
+        idx : int, slice or list
+            Index of the element(s) to get.
+        Returns
+        -------
+        `ndarray` object(s)
+            When `idx` is an int, returns a single 2D array.
+            When `idx` is either a slice or a list, returns a list of 2D arrays. 
+ """ + if isinstance(idx, int) or isinstance(idx, np.integer): + return self.data[self.offsets[idx]:self.offsets[idx]+self.lengths[idx]] + + elif type(idx) is slice: + compact_list = CompactList() + compact_list.data = self.data + compact_list.offsets = self.offsets[idx] + compact_list.lengths = self.lengths[idx] + return compact_list + + elif type(idx) is list: + compact_list = CompactList() + compact_list.data = self.data + compact_list.offsets = [self.offsets[i] for i in idx] + compact_list.lengths = [self.lengths[i] for i in idx] + return compact_list + + raise TypeError("Index must be a int or a slice! Not " + str(type(idx))) + + def copy(self): + # We could not only deepcopy the object because when slicing a CompactList it returns + # a view with modified `lengths` and `offsets` but `data` still points to the original data. + compact_list = CompactList() + total_lengths = np.sum(self.lengths) + compact_list.data = np.empty((total_lengths,) + self.data.shape[1:], dtype=self.dtype) + + cur_offset = 0 + for offset, lengths in zip(self.offsets, self.lengths): + compact_list.offsets.append(cur_offset) + compact_list.lengths.append(lengths) + compact_list.data[cur_offset:cur_offset+lengths] = self.data[offset:offset+lengths] + cur_offset += lengths + + return compact_list + + def __iter__(self): + if len(self.lengths) != len(self.offsets): + raise ValueError("CompactList object corrupted: len(self.lengths) != len(self.offsets)") + + for offset, lengths in zip(self.offsets, self.lengths): + yield self.data[offset: offset+lengths] + + def __len__(self): + return len(self.offsets) + + +class Streamline(object): + def __init__(self, points, scalars=None, properties=None): + self.points = points + self.scalars = scalars + self.properties = properties + + def __iter__(self): + return iter(self.points) + + def __len__(self): + return len(self.points) + + class Streamlines(object): ''' Class containing information about streamlines. 
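The `CompactList` added in this patch trades a Python list of small arrays for one contiguous buffer plus per-element `offsets` and `lengths`: integer indexing is a single slice into the shared buffer, and slicing or fancy indexing returns a cheap view-like object over the same `data`. A minimal sketch of that underlying layout in plain numpy (toy values, independent of nibabel):

    import numpy as np

    # Two 'streamlines' of 3 and 2 points, packed end to end.
    data = np.array([[0, 0, 0], [1, 1, 1], [2, 2, 2],
                     [10, 10, 10], [11, 11, 11]], dtype=np.float32)
    offsets = [0, 3]   # where each streamline starts in `data`
    lengths = [3, 2]   # how many points each streamline has

    def element(i):
        # The same arithmetic CompactList.__getitem__ uses for an int index.
        return data[offsets[i]:offsets[i] + lengths[i]]

    assert element(0).shape == (3, 3)
    assert element(1).shape == (2, 3)
    assert np.array_equal(element(1)[0], [10, 10, 10])
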
@@ -64,7 +186,18 @@ def points(self): @points.setter def points(self, value): - self._points = value if value else [] + if value is None or len(value) == 0: + self._points = CompactList(preallocate=(0, 3)) + + elif isinstance(value, CompactList): + self._points = value + + elif isinstance(value, list) or isinstance(value, tuple): + self._points = CompactList.from_list(value) + + else: + raise DataError("Unsupported data type: {0}".format(type(value))) + self.header.nb_streamlines = len(self.points) @property @@ -73,9 +206,19 @@ def scalars(self): @scalars.setter def scalars(self, value): - self._scalars = value if value else [] - self.header.nb_scalars_per_point = 0 + if value is None or len(value) == 0: + self._scalars = CompactList() + + elif isinstance(value, CompactList): + self._scalars = value + + elif isinstance(value, list) or isinstance(value, tuple): + self._scalars = CompactList.from_list(value) + else: + raise DataError("Unsupported data type: {0}".format(type(value))) + + self.header.nb_scalars_per_point = 0 if len(self.scalars) > 0 and len(self.scalars[0]) > 0: self.header.nb_scalars_per_point = len(self.scalars[0][0]) @@ -85,14 +228,18 @@ def properties(self): @properties.setter def properties(self, value): - self._properties = value if value else [] + if value is None: + value = [] + + self._properties = np.asarray(value, dtype=np.float32) self.header.nb_properties_per_streamline = 0 if len(self.properties) > 0: self.header.nb_properties_per_streamline = len(self.properties[0]) def __iter__(self): - return zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) + for data in zip_longest(self.points, self.scalars, self.properties, fillvalue=[]): + yield Streamline(*data) def __getitem__(self, idx): pts = self.points[idx] @@ -105,16 +252,16 @@ def __getitem__(self, idx): properties = self.properties[idx] if type(idx) is slice: - return list(zip_longest(pts, scalars, properties, fillvalue=[])) + return Streamlines(pts, scalars, properties) - return pts, scalars, properties + return Streamline(pts, scalars, properties) def __len__(self): return len(self.points) def copy(self): """ Returns a copy of this `Streamlines` object. """ - streamlines = Streamlines(self.points, self.scalars, self.properties) + streamlines = Streamlines(self.points.copy(), self.scalars.copy(), self.properties.copy()) streamlines._header = self.header.copy() return streamlines @@ -280,7 +427,7 @@ def __getitem__(self, idx): def __iter__(self): i = 0 for i, s in enumerate(self._data(), start=1): - yield s + yield Streamline(*s) # To be safe, update information about number of streamlines. self.header.nb_streamlines = i diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index c8a9465265..f7164eca5e 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -48,7 +48,7 @@ def test_streamlines_creation_from_arrays(self): assert_arrays_equal(streamlines.properties, []) # Check if we can iterate through the streamlines. - for points, scalars, props in streamlines: + for streamline in streamlines: pass # Only points @@ -59,7 +59,7 @@ def test_streamlines_creation_from_arrays(self): assert_arrays_equal(streamlines.properties, []) # Check if we can iterate through the streamlines. 
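
The test updates below reflect the new iteration protocol: iterating a `Streamlines` container now yields `Streamline` objects rather than bare `(points, scalars, properties)` tuples. A sketch of the intended call-site change (assuming the patched `nibabel.streamlines.base_format` is importable):

    import numpy as np
    from nibabel.streamlines.base_format import Streamlines

    streamlines = Streamlines(points=[np.random.rand(n, 3) for n in (3, 5, 2)])

    for streamline in streamlines:        # Streamline objects, not 3-tuples
        assert len(streamline) == len(streamline.points)
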
- for points, scalars, props in streamlines: + for streamline in streamlines: pass # Points, scalars and properties @@ -70,9 +70,11 @@ def test_streamlines_creation_from_arrays(self): assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) # Check if we can iterate through the streamlines. - for point, scalar, prop in streamlines: + for streamline in streamlines: pass + #streamlines = Streamlines(self.points, scalars) + def test_streamlines_getter(self): # Streamlines with only points streamlines = Streamlines(points=self.points) @@ -80,27 +82,24 @@ def test_streamlines_getter(self): selected_streamlines = streamlines[::2] assert_equal(len(selected_streamlines), (len(self.points)+1)//2) - points, scalars, properties = zip(*selected_streamlines) - assert_arrays_equal(points, self.points[::2]) - assert_equal(sum(map(len, scalars)), 0) - assert_equal(sum(map(len, properties)), 0) + assert_arrays_equal(selected_streamlines.points, self.points[::2]) + assert_equal(sum(map(len, selected_streamlines.scalars)), 0) + assert_equal(sum(map(len, selected_streamlines.properties)), 0) # Streamlines with points, scalars and properties streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) # Retrieve streamlines by their index - for i, (points, scalars, props) in enumerate(streamlines): - points_i, scalars_i, props_i = streamlines[i] - assert_array_equal(points_i, points) - assert_array_equal(scalars_i, scalars) - assert_array_equal(props_i, props) + for i, streamline in enumerate(streamlines): + assert_array_equal(streamline.points, streamlines[i].points) + assert_array_equal(streamline.scalars, streamlines[i].scalars) + assert_array_equal(streamline.properties, streamlines[i].properties) # Use slicing r_streamlines = streamlines[::-1] - r_points, r_scalars, r_props = zip(*r_streamlines) - assert_arrays_equal(r_points, self.points[::-1]) - assert_arrays_equal(r_scalars, self.colors[::-1]) - assert_arrays_equal(r_props, self.mean_curvature_torsion[::-1]) + assert_arrays_equal(r_streamlines.points, self.points[::-1]) + assert_arrays_equal(r_streamlines.scalars, self.colors[::-1]) + assert_arrays_equal(r_streamlines.properties, self.mean_curvature_torsion[::-1]) def test_streamlines_creation_from_coroutines(self): # Points, scalars and properties @@ -199,7 +198,7 @@ def test_lazy_streamlines_creation(self): assert_arrays_equal(streamlines.properties, []) # Check if we can iterate through the streamlines. - for point, scalar, prop in streamlines: + for streamline in streamlines: pass # Points, scalars and properties @@ -230,7 +229,7 @@ def test_lazy_streamlines_creation(self): assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) # Check if we can iterate through the streamlines. - for point, scalar, prop in streamlines: + for streamline in streamlines: pass def test_lazy_streamlines_indexing(self): @@ -290,7 +289,7 @@ def test_lazy_streamlines_len(self): # Once we iterated through the streamlines, we know the length. 
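
That is the contract the following assertions exercise: on a lazy container, `len()` has to consume the underlying generators once (and warns about it), but one completed iteration caches the count in the header. Roughly (a sketch against the pre-rename `LazyStreamlines` API used in this patch):

    import numpy as np
    from nibabel.streamlines.base_format import LazyStreamlines

    data = [np.random.rand(n, 3) for n in (3, 5, 2)]
    streamlines = LazyStreamlines(lambda: (p for p in data))

    for _ in streamlines:                            # one full pass...
        pass
    assert streamlines.header.nb_streamlines == 3    # ...caches the count
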
streamlines = LazyStreamlines(points, scalars, properties) assert_true(streamlines.header.nb_streamlines is None) - for s in streamlines: + for streamline in streamlines: pass assert_equal(streamlines.header.nb_streamlines, len(self.points)) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 893331cb77..e6ab8cab17 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -22,7 +22,7 @@ def isiterable(streamlines): try: - for point, scalar, prop in streamlines: + for _ in streamlines: pass except: return False diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 1a1aa2ce05..e29b288dc5 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -20,7 +20,7 @@ def isiterable(streamlines): try: - for point, scalar, prop in streamlines: + for streamline in streamlines: pass except: return False @@ -201,29 +201,33 @@ def test_write_erroneous_file(self): streamlines = Streamlines(self.points, scalars) assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) - # Inconsistent number of scalars between points - scalars = [[(1, 0, 0)]*1, - [(0, 1, 0), (0, 1)], - [(0, 0, 1)]*5] - - streamlines = Streamlines(self.points, scalars) - assert_raises(ValueError, TrkFile.save, streamlines, BytesIO()) - - # Inconsistent number of scalars between streamlines - scalars = [[(1, 0, 0)]*1, - [(0, 1)]*2, - [(0, 0, 1)]*5] - - streamlines = Streamlines(self.points, scalars) - assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) - - # Inconsistent number of properties - properties = [np.array([1.11, 1.22], dtype="f4"), - np.array([2.11], dtype="f4"), - np.array([3.11, 3.22], dtype="f4")] - streamlines = Streamlines(self.points, properties=properties) - assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) - + # # Unit test moved to test_base_format.py + # # Inconsistent number of scalars between points + # scalars = [[(1, 0, 0)]*1, + # [(0, 1, 0), (0, 1)], + # [(0, 0, 1)]*5] + + #streamlines = Streamlines(self.points, scalars) + #assert_raises(ValueError, TrkFile.save, streamlines, BytesIO()) + + # # Unit test moved to test_base_format.py + # # Inconsistent number of scalars between streamlines + # scalars = [[(1, 0, 0)]*1, + # [(0, 1)]*2, + # [(0, 0, 1)]*5] + + # streamlines = Streamlines(self.points, scalars) + # assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + + # # Unit test moved to test_base_format.py + # # Inconsistent number of properties + # properties = [np.array([1.11, 1.22], dtype="f4"), + # np.array([2.11], dtype="f4"), + # np.array([3.11, 3.22], dtype="f4")] + # streamlines = Streamlines(self.points, properties=properties) + # assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + + # # Unit test moved to test_base_format.py # No properties for every streamlines properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11, 2.22], dtype="f4")] diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 2a9ea9d29d..70df0385d0 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -12,6 +12,7 @@ from nibabel.openers import Opener from nibabel.volumeutils import (native_code, swapped_code) +from nibabel.streamlines.base_format import CompactList from nibabel.streamlines.base_format import StreamlinesFile from nibabel.streamlines.base_format import DataError, HeaderError, HeaderWarning from nibabel.streamlines.base_format 
import Streamlines, LazyStreamlines
@@ -139,11 +140,23 @@ def __init__(self, fileobj):
             # Keep the file position where the data begin.
             self.offset_data = f.tell()

+        if f.name is not None and self.header[Field.NB_STREAMLINES] > 0:
+            filesize = os.path.getsize(f.name) - self.offset_data
+            # Remove properties
+            filesize -= self.header[Field.NB_STREAMLINES] * self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * 4.
+            # Remove the points count at the beginning of each streamline.
+            filesize -= self.header[Field.NB_STREAMLINES] * 4.
+            # Get nb points.
+            nb_points = filesize / ((3 + self.header[Field.NB_SCALARS_PER_POINT]) * 4.)
+            self.header[Field.NB_POINTS] = int(nb_points)
+
     def __iter__(self):
         i4_dtype = np.dtype(self.endianness + "i4")
         f4_dtype = np.dtype(self.endianness + "f4")

         with Opener(self.fileobj) as f:
             start_position = f.tell()

             nb_pts_and_scalars = int(3 + self.header[Field.NB_SCALARS_PER_POINT])
@@ -257,13 +270,13 @@ def write(self, streamlines):
         i4_dtype = np.dtype("i4")
         f4_dtype = np.dtype("f4")

-        for points, scalars, properties in streamlines:
-            if len(scalars) > 0 and len(scalars) != len(points):
+        for s in streamlines:
+            if len(s.scalars) > 0 and len(s.scalars) != len(s.points):
                 raise DataError("Missing scalars for some points!")

-            points = np.array(points, dtype=f4_dtype)
-            scalars = np.array(scalars, dtype=f4_dtype).reshape((len(points), -1))
-            properties = np.array(properties, dtype=f4_dtype)
+            points = np.asarray(s.points, dtype=f4_dtype)
+            scalars = np.asarray(s.scalars, dtype=f4_dtype).reshape((len(points), -1))
+            properties = np.asarray(s.properties, dtype=f4_dtype)

             # TRK's streamlines need to be in 'voxelmm' space
             points = points * self.header[Field.VOXEL_SIZES]
@@ -413,6 +426,45 @@ def load(fileobj, ref=None, lazy_load=False):
                 streamlines.scalars = lambda: []
             if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0:
                 streamlines.properties = lambda: []
+        elif Field.NB_POINTS in trk_reader.header:
+            # 'count' field is provided, so we can avoid creating a list of
+            # numpy arrays (more memory efficient).
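
For reference, the `NB_POINTS` inference added to `TrkReader.__init__` above simply inverts TRK's fixed record layout: 4 bytes of point count plus `4 * nb_properties` bytes per streamline, then `4 * (3 + nb_scalars)` bytes per point. With made-up numbers:

    # Hypothetical figures, for illustration only.
    nb_streamlines = 100
    nb_scalars_per_point = 3        # e.g. one RGB triplet per point
    nb_properties = 2
    data_bytes = 121200             # file size minus the header offset

    overhead = nb_streamlines * (4 + 4 * nb_properties)   # counts + properties
    bytes_per_point = 4 * (3 + nb_scalars_per_point)      # xyz + scalars, float32
    nb_points = (data_bytes - overhead) // bytes_per_point
    assert nb_points == 5000
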
+
+            nb_streamlines = trk_reader.header[Field.NB_STREAMLINES]
+            nb_points = trk_reader.header[Field.NB_POINTS]
+
+            points = CompactList(preallocate=(nb_points, 3))
+            scalars = CompactList(preallocate=(nb_points, trk_reader.header[Field.NB_SCALARS_PER_POINT]))
+            properties = np.empty((nb_streamlines, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]),
+                                  dtype=np.float32)
+
+            offset = 0
+            offsets = []
+            lengths = []
+            for i, (pts, scals, props) in enumerate(data()):
+                offsets.append(offset)
+                lengths.append(len(pts))
+                points.data[offset:offset+len(pts)] = pts
+                scalars.data[offset:offset+len(scals)] = scals
+                properties[i] = props
+                offset += len(pts)
+
+            points.offsets = offsets
+            scalars.offsets = offsets
+            points.lengths = lengths
+            scalars.lengths = lengths
+            streamlines = Streamlines(points, scalars, properties)
+
+            # Overwrite scalars and properties if there are none.
+            if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0:
+                streamlines.scalars = []
+            if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0:
+                streamlines.properties = []
+
         else:
             streamlines = Streamlines(*zip(*data()))

From 7f715347a311930d2bb22c80835f4715baef5b98 Mon Sep 17 00:00:00 2001
From: Marc-Alexandre Côté
Date: Fri, 30 Oct 2015 23:08:55 -0400
Subject: [PATCH 011/135] Old unit tests are all passing

---
 nibabel/__init__.py                           |   2 +-
 nibabel/benchmarks/bench_streamlines.py       |  12 +-
 nibabel/streamlines/__init__.py               |   2 +-
 nibabel/streamlines/base_format.py            | 420 +++++++++++-------
 nibabel/streamlines/header.py                 |   4 +-
 nibabel/streamlines/tests/test_base_format.py | 305 ++++++++++---
 nibabel/streamlines/tests/test_header.py      |   4 +-
 nibabel/streamlines/tests/test_streamlines.py |  16 +-
 nibabel/streamlines/tests/test_trk.py         |  24 +-
 nibabel/streamlines/trk.py                    | 168 +++----
 10 files changed, 633 insertions(+), 324 deletions(-)

diff --git a/nibabel/__init__.py b/nibabel/__init__.py
index 8cb1d95e2e..163bcd9e15 100644
--- a/nibabel/__init__.py
+++ b/nibabel/__init__.py
@@ -63,7 +63,7 @@
     apply_orientation, aff2axcodes)
 from .imageclasses import class_map, ext_map, all_image_classes
 from . import trackvis
-from .streamlines import Streamlines
+from .streamlines import Tractogram
 from . import mriutils
 from . 
import viewers
diff --git a/nibabel/benchmarks/bench_streamlines.py b/nibabel/benchmarks/bench_streamlines.py
index 4eba60be7d..3d0da178c1 100644
--- a/nibabel/benchmarks/bench_streamlines.py
+++ b/nibabel/benchmarks/bench_streamlines.py
@@ -45,14 +45,14 @@ def bench_load_trk():
     streamlines = Streamlines(points)
     TrkFile.save(streamlines, trk_file)

-    from pycallgraph import PyCallGraph
-    from pycallgraph.output import GraphvizOutput
+    # from pycallgraph import PyCallGraph
+    # from pycallgraph.output import GraphvizOutput

-    with PyCallGraph(output=GraphvizOutput()):
-        #nib.streamlines.load(trk_file, ref=None, lazy_load=False)
+    # with PyCallGraph(output=GraphvizOutput()):
+    #     #nib.streamlines.load(trk_file, ref=None, lazy_load=False)

-        mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, ref=None, lazy_load=False)', repeat)
-        print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new))
+    mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, ref=None, lazy_load=False)', repeat)
+    print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new))

     mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file, points_space="voxel")', repeat)
     print("Old: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old))
diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py
index 7f3d7dcdf2..e23757013f 100644
--- a/nibabel/streamlines/__init__.py
+++ b/nibabel/streamlines/__init__.py
@@ -1,5 +1,5 @@
 from .header import Field
-from .base_format import Streamlines, LazyStreamlines
+from .base_format import Tractogram, LazyTractogram

 from nibabel.streamlines.trk import TrkFile
 #from nibabel.streamlines.tck import TckFile
diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py
index 8c6f3a9973..95fc2669b5 100644
--- a/nibabel/streamlines/base_format.py
+++ b/nibabel/streamlines/base_format.py
@@ -1,10 +1,11 @@
+import itertools
 import numpy as np

 from warnings import warn

 from nibabel.externals.six.moves import zip_longest
 from nibabel.affines import apply_affine

-from .header import StreamlinesHeader
+from .header import TractogramHeader
 from .utils import pop

@@ -25,119 +26,198 @@ class DataError(Exception):
     pass


 class CompactList(object):
-    def __init__(self, preallocate=(0,), dtype=np.float32):
-        self.dtype = dtype
-        self.data = np.empty(preallocate, dtype=dtype)
-        self.offsets = []
-        self.lengths = []
-
-    @classmethod
-    def from_list(cls, elements):
-        """ Fast way to create a ``CompactList`` from a list of 2D ndarrays.
+    """ Class for compacting a list of ndarrays with matching shape except for
+    the first dimension.
+    """
+    def __init__(self, iterable=None):
+        """
         Parameters
         ----------
-        elements : list
-            List of 2D ndarrays of same shape except on the first dimension.
+        iterable : iterable (optional)
+            If specified, create a ``CompactList`` object initialized from
+            iterable's items. Otherwise, create an empty ``CompactList``.
         """
-        if len(elements) == 0:
-            return cls()
+        # Create new empty `CompactList` object.
+        self._data = None
+        self._offsets = []
+        self._lengths = []
+
+        if iterable is not None:
+            # Initialize the `CompactList` object from iterable's items.
+            BUFFER_SIZE = 1000
+
+            offset = 0
+            for i, e in enumerate(iterable):
+                e = np.asarray(e)
+                if i == 0:
+                    self._data = np.empty((BUFFER_SIZE,) + e.shape[1:], dtype=e.dtype)
+
+                end = offset + len(e)
+                if end >= len(self._data):
+                    # Resize is needed (at least `len(e)` items will be added).
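
The buffering being set up here (and completed by the `resize` on the next line) is the usual amortized-growth trick: over-allocate, fill in place, and shrink once at the end, instead of paying a full reallocation per element as the old `append`-based path did. The same idea in isolation (a sketch; the buffer size is arbitrary):

    import numpy as np

    BUFFER_SIZE = 1000

    def pack(arrays):
        data, offset = None, 0
        for e in map(np.asarray, arrays):
            if data is None:    # first element fixes dtype and trailing shape
                data = np.empty((BUFFER_SIZE,) + e.shape[1:], dtype=e.dtype)
            if offset + len(e) >= len(data):    # grow before writing
                data = np.resize(data, (len(data) + len(e) + BUFFER_SIZE,) + e.shape[1:])
            data[offset:offset + len(e)] = e
            offset += len(e)
        return None if data is None else data[:offset]

    chunks = [np.random.rand(n, 3) for n in (3, 5, 2)]
    assert len(pack(chunks)) == 10
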
+                    self._data.resize((len(self._data) + len(e)+BUFFER_SIZE,) + self.shape)
+
+                self._offsets.append(offset)
+                self._lengths.append(len(e))
+                self._data[offset:offset+len(e)] = e
+                offset += len(e)
+
+            # Clear unused memory.
+            if self._data is not None:
+                self._data.resize((offset,) + self.shape)

-        first_element = np.asarray(elements[0])
-        s = cls(preallocate=(0,) + first_element.shape[1:], dtype=first_element.dtype)
-        s.extend(elements)
-        return s
+    @property
+    def shape(self):
+        """ Returns the matching shape of the elements in this compact list. """
+        if self._data is None:
+            return None
+
+        return self._data.shape[1:]

     def append(self, element):
-        """
+        """ Appends `element` to this compact list.
+
         Parameters
         ----------
-        element : 2D ndarray with `element.shape[1:] == self.shape`
-            Element to add.
-        Note
-        ----
+        element : ndarray
+            Element to append. Its shape must match that of already inserted
+            elements, except for the first dimension.
+
+        Notes
+        -----
         If you need to add multiple elements you should consider
-        `CompactList.from_list` or `CompactList.extend`.
+        `CompactList.extend`.
         """
-        self.offsets.append(len(self.data))
-        self.lengths.append(len(element))
-        self.data = np.append(self.data, element, axis=0)
+        if self._data is None:
+            self._data = np.asarray(element).copy()
+            self._offsets.append(0)
+            self._lengths.append(len(element))
+            return
+
+        if element.shape[1:] != self.shape:
+            raise ValueError("All dimensions, except the first one, must match exactly")
+
+        self._offsets.append(len(self._data))
+        self._lengths.append(len(element))
+        self._data = np.append(self._data, element, axis=0)

     def extend(self, elements):
+        """ Appends all `elements` to this compact list.
+
+        Parameters
+        ----------
+        elements : list of ndarrays or ``CompactList`` object
+            Elements to append. Their shapes must match that of already
+            inserted elements, except for the first dimension.
+        """
         if isinstance(elements, CompactList):
-            self.data = np.concatenate([self.data, elements.data], axis=0)
-            offset = self.offsets[-1] + self.lengths[-1] if len(self) > 0 else 0
-            self.lengths.extend(elements.lengths)
-            self.offsets.extend(np.cumsum([offset] + elements.lengths).tolist()[:-1])
+            self._data = np.concatenate([self._data, elements._data], axis=0)
+            offset = self._offsets[-1] + self._lengths[-1] if len(self) > 0 else 0
+            self._lengths.extend(elements._lengths)
+            self._offsets.extend(np.cumsum([offset] + elements._lengths).tolist()[:-1])
         else:
-            self.data = np.concatenate([self.data] + list(elements), axis=0)
-            offset = self.offsets[-1] + self.lengths[-1] if len(self) > 0 else 0
+            self._data = np.concatenate([self._data] + list(elements), axis=0)
+            offset = self._offsets[-1] + self._lengths[-1] if len(self) > 0 else 0
             lengths = map(len, elements)
-            self.lengths.extend(lengths)
-            self.offsets.extend(np.cumsum([offset] + lengths).tolist()[:-1])
+            self._lengths.extend(lengths)
+            self._offsets.extend(np.cumsum([offset] + lengths).tolist()[:-1])
+
+    def copy(self):
+        """ Creates a copy of this ``CompactList`` object. """
+        # We cannot just deepcopy this object since we don't know if it has been
+        # created using slicing. If that is the case, `self._data` probably contains
+        # more data than necessary, so we copy only the elements referenced by
+        # `self._offsets`.
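
In other words, `copy()` doubles as compaction: it rebuilds a contiguous `_data` array holding only the rows the view actually references. The observable behaviour, sketched against the class above:

    import numpy as np
    from nibabel.streamlines.base_format import CompactList

    clist = CompactList([np.random.rand(n, 3) for n in (3, 5, 2)])

    view = clist[::2]                   # a view: shares the full _data buffer
    assert view._data is clist._data

    compacted = view.copy()             # a copy: keeps only referenced rows
    assert len(compacted._data) == sum(view._lengths)
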
+        compact_list = CompactList()
+        total_lengths = np.sum(self._lengths)
+        compact_list._data = np.empty((total_lengths,) + self._data.shape[1:], dtype=self._data.dtype)
+
+        cur_offset = 0
+        for offset, lengths in zip(self._offsets, self._lengths):
+            compact_list._offsets.append(cur_offset)
+            compact_list._lengths.append(lengths)
+            compact_list._data[cur_offset:cur_offset+lengths] = self._data[offset:offset+lengths]
+            cur_offset += lengths
+
+        return compact_list

     def __getitem__(self, idx):
         """ Gets element(s) through indexing.
+
         Parameters
         ----------
         idx : int, slice or list
             Index of the element(s) to get.
+
         Returns
         -------
-        `ndarray` object(s)
-            When `idx` is an int, returns a single 2D array.
-            When `idx` is either a slice or a list, returns a list of 2D arrays.
+        ndarray object(s)
+            When `idx` is an int, returns a single ndarray.
+            When `idx` is either a slice or a list, returns a list of ndarrays.
         """
         if isinstance(idx, int) or isinstance(idx, np.integer):
-            return self.data[self.offsets[idx]:self.offsets[idx]+self.lengths[idx]]
+            return self._data[self._offsets[idx]:self._offsets[idx]+self._lengths[idx]]

         elif type(idx) is slice:
+            # TODO: Should we have a CompactListView class that would be
+            # returned when slicing?
             compact_list = CompactList()
-            compact_list.data = self.data
-            compact_list.offsets = self.offsets[idx]
-            compact_list.lengths = self.lengths[idx]
+            compact_list._data = self._data
+            compact_list._offsets = self._offsets[idx]
+            compact_list._lengths = self._lengths[idx]
             return compact_list

         elif type(idx) is list:
+            # TODO: Should we have a CompactListView class that would be
+            # returned when doing advanced indexing?
             compact_list = CompactList()
-            compact_list.data = self.data
-            compact_list.offsets = [self.offsets[i] for i in idx]
-            compact_list.lengths = [self.lengths[i] for i in idx]
+            compact_list._data = self._data
+            compact_list._offsets = [self._offsets[i] for i in idx]
+            compact_list._lengths = [self._lengths[i] for i in idx]
             return compact_list

         raise TypeError("Index must be an int, a slice or a list! Not " + str(type(idx)))

-    def copy(self):
-        # We cannot simply deepcopy this object: slicing a CompactList returns a view
-        # with modified `lengths` and `offsets`, while `data` still points to the original array.
-        compact_list = CompactList()
-        total_lengths = np.sum(self.lengths)
-        compact_list.data = np.empty((total_lengths,) + self.data.shape[1:], dtype=self.dtype)
-
-        cur_offset = 0
-        for offset, lengths in zip(self.offsets, self.lengths):
-            compact_list.offsets.append(cur_offset)
-            compact_list.lengths.append(lengths)
-            compact_list.data[cur_offset:cur_offset+lengths] = self.data[offset:offset+lengths]
-            cur_offset += lengths
-
-        return compact_list
-
     def __iter__(self):
-        if len(self.lengths) != len(self.offsets):
-            raise ValueError("CompactList object corrupted: len(self.lengths) != len(self.offsets)")
+        if len(self._lengths) != len(self._offsets):
+            raise ValueError("CompactList object corrupted: len(self._lengths) != len(self._offsets)")

-        for offset, lengths in zip(self.offsets, self.lengths):
-            yield self.data[offset: offset+lengths]
+        for offset, lengths in zip(self._offsets, self._lengths):
+            yield self._data[offset: offset+lengths]

     def __len__(self):
-        return len(self.offsets)
+        return len(self._offsets)

+    def __repr__(self):
+        return repr(list(self))

-class Streamline(object):
+class TractogramItem(object):
+    ''' Class containing information about one streamline.
+
+    ``TractogramItem`` objects have three main properties: ``points``, ``scalars``
+    and ``properties``.
+
+    Parameters
+    ----------
+    points : ndarray of shape (N, 3)
+        Points of this streamline represented as an ndarray of shape (N, 3)
+        where N is the number of points.
+
+    scalars : ndarray of shape (N, M)
+        Scalars associated with each point of this streamline and represented
+        as an ndarray of shape (N, M) where N is the number of points and
+        M is the number of scalars (excluding the three coordinates).
+
+    properties : ndarray of shape (P,)
+        Properties associated with this streamline and represented as an
+        ndarray of shape (P,) where P is the number of properties.
+    '''
     def __init__(self, points, scalars=None, properties=None):
-        self.points = points
-        self.scalars = scalars
-        self.properties = properties
+        #if scalars is not None and len(points) != len(scalars):
+        #    raise ValueError("First dimension of points and scalars must match.")
+
+        self.points = np.asarray(points)
+        self.scalars = np.asarray([] if scalars is None else scalars)
+        self.properties = np.asarray([] if properties is None else properties)

     def __iter__(self):
         return iter(self.points)
@@ -146,11 +226,11 @@ def __len__(self):
         return len(self.points)


-class Streamlines(object):
+class Tractogram(object):
     ''' Class containing information about streamlines.

-    Streamlines objects have three main properties: ``points``, ``scalars``
-    and ``properties``. Streamlines objects can be iterate over producing
+    Tractogram objects have three main properties: ``points``, ``scalars``
+    and ``properties``. Tractogram objects can be iterated over, producing a
     tuple of ``points``, ``scalars`` and ``properties`` for each streamline.

     Parameters
@@ -171,11 +251,71 @@ class Streamlines(object):
         associated to each streamline.
     '''
     def __init__(self, points=None, scalars=None, properties=None):
-        self._header = StreamlinesHeader()
+        self._header = TractogramHeader()
         self.points = points
         self.scalars = scalars
         self.properties = properties

+    @classmethod
+    def create_from_generator(cls, gen):
+        BUFFER_SIZE = 1000
+
+        points = CompactList()
+        scalars = CompactList()
+        properties = np.array([])
+
+        gen = iter(gen)
+        try:
+            first_element = next(gen)
+            gen = itertools.chain([first_element], gen)
+        except StopIteration:
+            return cls(points, scalars, properties)
+
+        # Allocate some buffer memory.
+        pts = np.asarray(first_element[0])
+        scals = np.asarray(first_element[1])
+        props = np.asarray(first_element[2])
+
+        points._data = np.empty((BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype)
+        scalars._data = np.empty((BUFFER_SIZE, scals.shape[1]), dtype=scals.dtype)
+        properties = np.empty((BUFFER_SIZE, props.shape[0]), dtype=props.dtype)
+
+        offset = 0
+        for i, (pts, scals, props) in enumerate(gen):
+            pts = np.asarray(pts)
+            scals = np.asarray(scals)
+            props = np.asarray(props)
+
+            end = offset + len(pts)
+            if end >= len(points._data):
+                # Resize is needed (at least `len(pts)` items will be added).
+                points._data.resize((len(points._data) + len(pts)+BUFFER_SIZE, pts.shape[1]))
+                scalars._data.resize((len(scalars._data) + len(scals)+BUFFER_SIZE, scals.shape[1]))
+
+            points._offsets.append(offset)
+            points._lengths.append(len(pts))
+            points._data[offset:offset+len(pts)] = pts
+            scalars._data[offset:offset+len(scals)] = scals
+
+            offset += len(pts)
+
+            if i >= len(properties):
+                properties.resize((len(properties) + BUFFER_SIZE, props.shape[0]))
+
+            properties[i] = props
+
+        # Share offsets and lengths between points and scalars.
+ scalars._offsets = points._offsets + scalars._lengths = points._lengths + + # Clear unused memory. + points._data.resize((offset, pts.shape[1])) + scalars._data.resize((offset, scals.shape[1])) + properties.resize((i+1, props.shape[0])) + + return cls(points, scalars, properties) + + @property def header(self): return self._header @@ -186,17 +326,9 @@ def points(self): @points.setter def points(self, value): - if value is None or len(value) == 0: - self._points = CompactList(preallocate=(0, 3)) - - elif isinstance(value, CompactList): - self._points = value - - elif isinstance(value, list) or isinstance(value, tuple): - self._points = CompactList.from_list(value) - - else: - raise DataError("Unsupported data type: {0}".format(type(value))) + self._points = value + if not isinstance(value, CompactList): + self._points = CompactList(value) self.header.nb_streamlines = len(self.points) @@ -206,17 +338,9 @@ def scalars(self): @scalars.setter def scalars(self, value): - if value is None or len(value) == 0: - self._scalars = CompactList() - - elif isinstance(value, CompactList): - self._scalars = value - - elif isinstance(value, list) or isinstance(value, tuple): - self._scalars = CompactList.from_list(value) - - else: - raise DataError("Unsupported data type: {0}".format(type(value))) + self._scalars = value + if not isinstance(value, CompactList): + self._scalars = CompactList(value) self.header.nb_scalars_per_point = 0 if len(self.scalars) > 0 and len(self.scalars[0]) > 0: @@ -228,18 +352,17 @@ def properties(self): @properties.setter def properties(self, value): + self._properties = np.asarray(value) if value is None: - value = [] + self._properties = np.empty((len(self), 0), dtype=np.float32) - self._properties = np.asarray(value, dtype=np.float32) self.header.nb_properties_per_streamline = 0 - if len(self.properties) > 0: self.header.nb_properties_per_streamline = len(self.properties[0]) def __iter__(self): - for data in zip_longest(self.points, self.scalars, self.properties, fillvalue=[]): - yield Streamline(*data) + for data in zip_longest(self.points, self.scalars, self.properties, fillvalue=None): + yield TractogramItem(*data) def __getitem__(self, idx): pts = self.points[idx] @@ -252,70 +375,43 @@ def __getitem__(self, idx): properties = self.properties[idx] if type(idx) is slice: - return Streamlines(pts, scalars, properties) + return Tractogram(pts, scalars, properties) - return Streamline(pts, scalars, properties) + return TractogramItem(pts, scalars, properties) def __len__(self): return len(self.points) def copy(self): - """ Returns a copy of this `Streamlines` object. """ - streamlines = Streamlines(self.points.copy(), self.scalars.copy(), self.properties.copy()) + """ Returns a copy of this `Tractogram` object. """ + streamlines = Tractogram(self.points.copy(), self.scalars.copy(), self.properties.copy()) streamlines._header = self.header.copy() return streamlines - def transform(self, affine, lazy=False): + def apply_affine(self, affine): """ Applies an affine transformation on the points of each streamline. + This is performed in-place. + Parameters ---------- affine : 2D array (4,4) Transformation that will be applied on each streamline. - lazy : bool (optional) - If true output will be a generator of arrays instead of a list. - - Returns - ------- - streamlines - If `lazy` is true, a `LazyStreamlines` object is returned, - otherwise a `Streamlines` object is returned. In both case, - streamlines are in a space defined by `affine`. 
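
The in-place `apply_affine` that replaces `transform` below walks the flat `_data` buffer in fixed-size chunks, so no full-size temporary array is ever allocated. The core of that loop in isolation (a sketch; the chunk size is arbitrary):

    import numpy as np
    from nibabel.affines import apply_affine

    BUFFER_SIZE = 10000

    def transform_inplace(data, affine):
        # `data` is the (N, 3) flat buffer shared by all streamlines.
        for i in range(0, len(data), BUFFER_SIZE):
            data[i:i + BUFFER_SIZE] = apply_affine(affine, data[i:i + BUFFER_SIZE])

    data = np.random.rand(25000, 3)
    transform_inplace(data, np.diag([2., 3., 4., 1.]))  # e.g. scale by voxel sizes
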
""" - points = lambda: (apply_affine(affine, pts) for pts in self.points) - if not lazy: - points = list(points()) - - streamlines = self.copy() - streamlines.points = points - streamlines.header.to_world_space = np.dot(streamlines.header.to_world_space, - np.linalg.inv(affine)) + if len(self.points) == 0: + return - return streamlines - - def to_world_space(self, lazy=False): - """ Sends the streamlines back into world space. - - Parameters - ---------- - lazy : bool (optional) - If true output will be a generator of arrays instead of a list. - - Returns - ------- - streamlines - If `lazy` is true, a `LazyStreamlines` object is returned, - otherwise a `Streamlines` object is returned. In both case, - streamlines are in world space. - """ - return self.transform(self.header.to_world_space, lazy) + BUFFER_SIZE = 10000 + for i in range(0, len(self.points._data), BUFFER_SIZE): + pts = self.points._data[i:i+BUFFER_SIZE] + self.points._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts) -class LazyStreamlines(Streamlines): +class LazyTractogram(Tractogram): ''' Class containing information about streamlines. - Streamlines objects have four main properties: ``header``, ``points``, - ``scalars`` and ``properties``. Streamlines objects are iterable and + Tractogram objects have four main properties: ``header``, ``points``, + ``scalars`` and ``properties``. Tractogram objects are iterable and produce tuple of ``points``, ``scalars`` and ``properties`` for each streamline. @@ -347,7 +443,7 @@ class LazyStreamlines(Streamlines): values as ``points``. ''' def __init__(self, points_func=lambda:[], scalars_func=lambda: [], properties_func=lambda: [], getitem_func=None): - super(LazyStreamlines, self).__init__(points_func, scalars_func, properties_func) + super(LazyTractogram, self).__init__(points_func, scalars_func, properties_func) self._data = lambda: zip_longest(self.points, self.scalars, self.properties, fillvalue=[]) self._getitem = getitem_func @@ -420,14 +516,14 @@ def properties(self, value): def __getitem__(self, idx): if self._getitem is None: - raise AttributeError('`LazyStreamlines` does not support indexing.') + raise AttributeError('`LazyTractogram` does not support indexing.') return self._getitem(idx) def __iter__(self): i = 0 for i, s in enumerate(self._data(), start=1): - yield Streamline(*s) + yield TractogramItem(*s) # To be safe, update information about number of streamlines. self.header.nb_streamlines = i @@ -440,14 +536,14 @@ def __len__(self): " streamlines, you might want to set it beforehand via" " `self.header.nb_streamlines`." " Note this will consume any generators used to create this" - " `LazyStreamlines` object.", UsageWarning) + " `LazyTractogram` object.", UsageWarning) return sum(1 for _ in self) return self.header.nb_streamlines def copy(self): - """ Returns a copy of this `LazyStreamlines` object. """ - streamlines = LazyStreamlines(self._points, self._scalars, self._properties) + """ Returns a copy of this `LazyTractogram` object. """ + streamlines = LazyTractogram(self._points, self._scalars, self._properties) streamlines._header = self.header.copy() return streamlines @@ -461,23 +557,23 @@ def transform(self, affine): Returns ------- - streamlines : `LazyStreamlines` object - Streamlines living in a space defined by `affine`. + streamlines : `LazyTractogram` object + Tractogram living in a space defined by `affine`. 
""" - return super(LazyStreamlines, self).transform(affine, lazy=True) + return super(LazyTractogram, self).transform(affine, lazy=True) def to_world_space(self): """ Sends the streamlines back into world space. Returns ------- - streamlines : `LazyStreamlines` object - Streamlines living in world space. + streamlines : `LazyTractogram` object + Tractogram living in world space. """ - return super(LazyStreamlines, self).to_world_space(lazy=True) + return super(LazyTractogram, self).to_world_space(lazy=True) -class StreamlinesFile: +class TractogramFile: ''' Convenience class to encapsulate streamlines file format. ''' @classmethod @@ -533,9 +629,9 @@ def load(fileobj, ref, lazy_load=True): Returns ------- - streamlines : Streamlines object + streamlines : Tractogram object Returns an object containing streamlines' data and header - information. See 'nibabel.Streamlines'. + information. See 'nibabel.Tractogram'. ''' raise NotImplementedError() @@ -545,9 +641,9 @@ def save(streamlines, fileobj, ref=None): Parameters ---------- - streamlines : Streamlines object + streamlines : Tractogram object Object containing streamlines' data and header information. - See 'nibabel.Streamlines'. + See 'nibabel.Tractogram'. fileobj : string or file-like object If string, a filename; otherwise an open file-like object @@ -577,7 +673,7 @@ def pretty_print(streamlines): raise NotImplementedError() -# class DynamicStreamlineFile(StreamlinesFile): +# class DynamicTractogramFile(TractogramFile): # ''' Convenience class to encapsulate streamlines file format # that supports appending streamlines to an existing file. # ''' diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py index 7a7ec63b3c..706afd348a 100644 --- a/nibabel/streamlines/header.py +++ b/nibabel/streamlines/header.py @@ -24,7 +24,7 @@ class Field: ENDIAN = "endian" -class StreamlinesHeader(object): +class TractogramHeader(object): def __init__(self): self._nb_streamlines = None self._nb_scalars_per_point = None @@ -90,7 +90,7 @@ def extra(self, value): self._extra = OrderedDict(value) def copy(self): - header = StreamlinesHeader() + header = TractogramHeader() header._nb_streamlines = self.nb_streamlines header.nb_scalars_per_point = self.nb_scalars_per_point header.nb_properties_per_streamline = self.nb_properties_per_streamline diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index f7164eca5e..b699066c1c 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -7,16 +7,232 @@ from nibabel.testing import suppress_warnings, clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true from numpy.testing import assert_array_equal, assert_array_almost_equal -from nibabel.externals.six.moves import zip +from nibabel.externals.six.moves import zip, zip_longest from .. 
import base_format -from ..base_format import Streamlines, LazyStreamlines +from ..base_format import CompactList +from ..base_format import TractogramItem, Tractogram, LazyTractogram from ..base_format import UsageWarning DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') -class TestStreamlines(unittest.TestCase): +class TestCompactList(unittest.TestCase): + + def setUp(self): + rng = np.random.RandomState(42) + self.data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + self.lengths = map(len, self.data) + self.clist = CompactList(self.data) + + def test_creating_empty_compactlist(self): + clist = CompactList() + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + def test_creating_compactlist_from_list(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = map(len, data) + + clist = CompactList(data) + assert_equal(len(clist), len(data)) + assert_equal(len(clist._offsets), len(data)) + assert_equal(len(clist._lengths), len(data)) + assert_equal(clist._data.shape[0], sum(lengths)) + assert_equal(clist._data.shape[1], 3) + assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist._lengths, lengths) + assert_equal(clist.shape, data[0].shape[1:]) + + # Empty list + clist = CompactList([]) + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + def test_creating_compactlist_from_generator(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = map(len, data) + + gen = (e for e in data) + clist = CompactList(gen) + assert_equal(len(clist), len(data)) + assert_equal(len(clist._offsets), len(data)) + assert_equal(len(clist._lengths), len(data)) + assert_equal(clist._data.shape[0], sum(lengths)) + assert_equal(clist._data.shape[1], 3) + assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist._lengths, lengths) + assert_equal(clist.shape, data[0].shape[1:]) + + # Already consumed generator + clist = CompactList(gen) + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + def test_creating_compactlist_from_compact_list(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = map(len, data) + + clist = CompactList(data) + clist2 = CompactList(clist) + assert_equal(len(clist2), len(data)) + assert_equal(len(clist2._offsets), len(data)) + assert_equal(len(clist2._lengths), len(data)) + assert_equal(clist2._data.shape[0], sum(lengths)) + assert_equal(clist2._data.shape[1], 3) + assert_equal(clist2._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist2._lengths, lengths) + assert_equal(clist2.shape, data[0].shape[1:]) + + def test_compactlist_iter(self): + for e, d in zip(self.clist, self.data): + assert_array_equal(e, d) + + def test_compactlist_copy(self): + clist = self.clist.copy() + assert_array_equal(clist._data, self.clist._data) + assert_true(clist._data is not self.clist._data) + assert_array_equal(clist._offsets, self.clist._offsets) + assert_true(clist._offsets is not self.clist._offsets) + assert_array_equal(clist._lengths, 
self.clist._lengths) + assert_true(clist._lengths is not self.clist._lengths) + + assert_equal(clist.shape, self.clist.shape) + + # When taking a copy of a `CompactList` generated by slicing. + # Only needed data should be kept. + clist = self.clist[::2].copy() + + assert_true(clist._data.shape[0] < self.clist._data.shape[0]) + assert_true(len(clist) < len(self.clist)) + assert_true(clist._data is not self.clist._data) + + def test_compactlist_append(self): + # Maybe not necessary if `self.setUp` is always called before a + # test method, anyways create a copy just in case. + clist = self.clist.copy() + + rng = np.random.RandomState(1234) + element = rng.rand(rng.randint(10, 50), *self.clist.shape) + clist.append(element) + assert_equal(len(clist), len(self.clist)+1) + assert_equal(clist._offsets[-1], len(self.clist._data)) + assert_equal(clist._lengths[-1], len(element)) + assert_array_equal(clist._data[-len(element):], element) + + # Append with different shape. + element = rng.rand(rng.randint(10, 50), 42) + assert_raises(ValueError, clist.append, element) + + # Append to an empty CompactList. + clist = CompactList() + rng = np.random.RandomState(1234) + shape = (2, 3, 4) + element = rng.rand(rng.randint(10, 50), *shape) + clist.append(element) + + assert_equal(len(clist), 1) + assert_equal(clist._offsets[-1], 0) + assert_equal(clist._lengths[-1], len(element)) + assert_array_equal(clist._data, element) + assert_equal(clist.shape, shape) + + def test_compactlist_extend(self): + # Maybe not necessary if `self.setUp` is always called before a + # test method, anyways create a copy just in case. + clist = self.clist.copy() + + rng = np.random.RandomState(1234) + shape = self.clist.shape + new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(5)] + lengths = map(len, new_data) + clist.extend(new_data) + assert_equal(len(clist), len(self.clist)+len(new_data)) + assert_array_equal(clist._offsets[-len(new_data):], + len(self.clist._data) + np.cumsum([0] + lengths[:-1])) + + assert_equal(clist._lengths[-len(new_data):], lengths) + assert_array_equal(clist._data[-sum(lengths):], + np.concatenate(new_data, axis=0)) + + # Extend with another `CompactList` object. + clist = self.clist.copy() + new_data = CompactList(new_data) + clist.extend(new_data) + assert_equal(len(clist), len(self.clist)+len(new_data)) + assert_array_equal(clist._offsets[-len(new_data):], + len(self.clist._data) + np.cumsum([0] + lengths[:-1])) + + assert_equal(clist._lengths[-len(new_data):], lengths) + assert_array_equal(clist._data[-sum(lengths):], new_data._data) + + def test_compactlist_getitem(self): + # Get one item + for i, e in enumerate(self.clist): + assert_array_equal(self.clist[i], e) + + # Get multiple items (this will create a view). + clist_view = self.clist[range(len(self.clist))] + assert_true(clist_view is not self.clist) + assert_true(clist_view._data is self.clist._data) + assert_true(clist_view._offsets is not self.clist._offsets) + assert_true(clist_view._lengths is not self.clist._lengths) + assert_array_equal(clist_view._offsets, self.clist._offsets) + assert_array_equal(clist_view._lengths, self.clist._lengths) + for e1, e2 in zip_longest(clist_view, self.clist): + assert_array_equal(e1, e2) + + # Get slice (this will create a view). 
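
Because slices and fancy indexing return views that alias the parent's `_data`, writing through a view's rows also changes the parent; the `copy()` tested above is the escape hatch. A sketch of that aliasing (assuming the patched `CompactList`):

    import numpy as np
    from nibabel.streamlines.base_format import CompactList

    clist = CompactList([np.arange(n * 3, dtype=float).reshape(-1, 3)
                         for n in (3, 5, 2)])

    view = clist[[0, 2]]     # fancy indexing also returns a view
    view[0][:] = 0.0         # ...so in-place writes alias the parent
    assert np.all(clist[0] == 0.0)
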
+ clist_view = self.clist[::2] + assert_true(clist_view is not self.clist) + assert_true(clist_view._data is self.clist._data) + assert_array_equal(clist_view._offsets, self.clist._offsets[::2]) + assert_array_equal(clist_view._lengths, self.clist._lengths[::2]) + for i, e in enumerate(clist_view): + assert_array_equal(e, self.clist[i*2]) + + +class TestTractogramItem(unittest.TestCase): + + def test_creating_tractogram_item(self): + rng = np.random.RandomState(42) + points = rng.rand(rng.randint(10, 50), 3) + scalars = rng.rand(len(points), 5) + properties = rng.rand(42) + + # Create a streamline with only points + s = TractogramItem(points) + assert_equal(len(s), len(points)) + assert_array_equal(s.scalars, []) + assert_array_equal(s.properties, []) + + # Create a streamline with points, scalars and properties. + s = TractogramItem(points, scalars, properties) + assert_equal(len(s), len(points)) + assert_array_equal(s.points, points) + assert_array_equal(list(s), points) + assert_equal(len(s), len(scalars)) + assert_array_equal(s.scalars, scalars) + assert_array_equal(s.properties, properties) + + # # Create a streamline with different number of scalars. + # scalars = rng.rand(len(points)+3, 5) + # assert_raises(ValueError, TractogramItem, points, scalars) + + +class TestTractogram(unittest.TestCase): def setUp(self): self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") @@ -41,7 +257,7 @@ def setUp(self): def test_streamlines_creation_from_arrays(self): # Empty - streamlines = Streamlines() + streamlines = Tractogram() assert_equal(len(streamlines), 0) assert_arrays_equal(streamlines.points, []) assert_arrays_equal(streamlines.scalars, []) @@ -52,7 +268,7 @@ def test_streamlines_creation_from_arrays(self): pass # Only points - streamlines = Streamlines(points=self.points) + streamlines = Tractogram(points=self.points) assert_equal(len(streamlines), len(self.points)) assert_arrays_equal(streamlines.points, self.points) assert_arrays_equal(streamlines.scalars, []) @@ -63,7 +279,7 @@ def test_streamlines_creation_from_arrays(self): pass # Points, scalars and properties - streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) + streamlines = Tractogram(self.points, self.colors, self.mean_curvature_torsion) assert_equal(len(streamlines), len(self.points)) assert_arrays_equal(streamlines.points, self.points) assert_arrays_equal(streamlines.scalars, self.colors) @@ -73,11 +289,11 @@ def test_streamlines_creation_from_arrays(self): for streamline in streamlines: pass - #streamlines = Streamlines(self.points, scalars) + #streamlines = Tractogram(self.points, scalars) def test_streamlines_getter(self): - # Streamlines with only points - streamlines = Streamlines(points=self.points) + # Tractogram with only points + streamlines = Tractogram(points=self.points) selected_streamlines = streamlines[::2] assert_equal(len(selected_streamlines), (len(self.points)+1)//2) @@ -86,8 +302,8 @@ def test_streamlines_getter(self): assert_equal(sum(map(len, selected_streamlines.scalars)), 0) assert_equal(sum(map(len, selected_streamlines.properties)), 0) - # Streamlines with points, scalars and properties - streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) + # Tractogram with points, scalars and properties + streamlines = Tractogram(self.points, self.colors, self.mean_curvature_torsion) # Retrieve streamlines by their index for i, streamline in enumerate(streamlines): @@ -107,23 +323,12 @@ def test_streamlines_creation_from_coroutines(self): 
scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - # To create streamlines from coroutines use `LazyStreamlines`. - assert_raises(TypeError, Streamlines, points, scalars, properties) - - def test_to_world_space(self): - streamlines = Streamlines(self.points) - - # World space is (RAS+) with voxel size of 2x3x4mm. - streamlines.header.voxel_sizes = (2, 3, 4) - - new_streamlines = streamlines.to_world_space() - for new_pts, pts in zip(new_streamlines.points, self.points): - for dim, size in enumerate(streamlines.header.voxel_sizes): - assert_array_almost_equal(new_pts[:, dim], size*pts[:, dim]) + # To create streamlines from coroutines use `LazyTractogram`. + assert_raises(TypeError, Tractogram, points, scalars, properties) def test_header(self): - # Empty Streamlines, with default header - streamlines = Streamlines() + # Empty Tractogram, with default header + streamlines = Tractogram() assert_equal(streamlines.header.nb_streamlines, 0) assert_equal(streamlines.header.nb_scalars_per_point, 0) assert_equal(streamlines.header.nb_properties_per_streamline, 0) @@ -131,7 +336,7 @@ def test_header(self): assert_array_equal(streamlines.header.to_world_space, np.eye(4)) assert_equal(streamlines.header.extra, {}) - streamlines = Streamlines(self.points, self.colors, self.mean_curvature_torsion) + streamlines = Tractogram(self.points, self.colors, self.mean_curvature_torsion) assert_equal(streamlines.header.nb_streamlines, len(self.points)) assert_equal(streamlines.header.nb_scalars_per_point, self.colors[0].shape[1]) @@ -151,7 +356,7 @@ def test_header(self): repr(streamlines.header) -class TestLazyStreamlines(unittest.TestCase): +class TestLazyTractogram(unittest.TestCase): def setUp(self): self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") @@ -175,22 +380,22 @@ def setUp(self): self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) def test_lazy_streamlines_creation(self): - # To create streamlines from arrays use `Streamlines`. - assert_raises(TypeError, LazyStreamlines, self.points) + # To create streamlines from arrays use `Tractogram`. + assert_raises(TypeError, LazyTractogram, self.points) # Points, scalars and properties points = (x for x in self.points) scalars = (x for x in self.colors) properties = (x for x in self.mean_curvature_torsion) - # Creating LazyStreamlines from generators is not allowed as + # Creating LazyTractogram from generators is not allowed as # generators get exhausted and are not reusable unline coroutines. 
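
That is exactly what the assertions below enforce: a generator can be consumed only once, so `LazyTractogram` requires zero-argument callables that build a fresh generator on every pass. The difference, sketched with the names used in this patch:

    import numpy as np
    from nibabel.streamlines.base_format import LazyTractogram

    data = [np.random.rand(n, 3) for n in (3, 5, 2)]

    gen = (p for p in data)
    # LazyTractogram(gen)    # rejected (TypeError): a generator is single-use

    streamlines = LazyTractogram(lambda: (p for p in data))
    assert sum(1 for _ in streamlines) == sum(1 for _ in streamlines)
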
- assert_raises(TypeError, LazyStreamlines, points) - assert_raises(TypeError, LazyStreamlines, self.points, scalars) - assert_raises(TypeError, LazyStreamlines, properties_func=properties) + assert_raises(TypeError, LazyTractogram, points) + assert_raises(TypeError, LazyTractogram, self.points, scalars) + assert_raises(TypeError, LazyTractogram, properties_func=properties) - # Empty `LazyStreamlines` - streamlines = LazyStreamlines() + # Empty `LazyTractogram` + streamlines = LazyTractogram() with suppress_warnings(): assert_equal(len(streamlines), 0) assert_arrays_equal(streamlines.points, []) @@ -206,7 +411,7 @@ def test_lazy_streamlines_creation(self): scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - streamlines = LazyStreamlines(points, scalars, properties) + streamlines = LazyTractogram(points, scalars, properties) with suppress_warnings(): assert_equal(len(streamlines), self.nb_streamlines) @@ -218,10 +423,10 @@ def test_lazy_streamlines_creation(self): assert_arrays_equal(streamlines.scalars, self.colors) assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) - # Create `LazyStreamlines` from a coroutine yielding 3-tuples + # Create `LazyTractogram` from a coroutine yielding 3-tuples data = lambda: (x for x in zip(self.points, self.colors, self.mean_curvature_torsion)) - streamlines = LazyStreamlines.create_from_data(data) + streamlines = LazyTractogram.create_from_data(data) with suppress_warnings(): assert_equal(len(streamlines), self.nb_streamlines) assert_arrays_equal(streamlines.points, self.points) @@ -237,18 +442,18 @@ def test_lazy_streamlines_indexing(self): scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - # By default, `LazyStreamlines` object does not support indexing. - streamlines = LazyStreamlines(points, scalars, properties) + # By default, `LazyTractogram` object does not support indexing. + streamlines = LazyTractogram(points, scalars, properties) assert_raises(AttributeError, streamlines.__getitem__, 0) - # Create a `LazyStreamlines` object with indexing support. + # Create a `LazyTractogram` object with indexing support. def getitem_without_properties(idx): if isinstance(idx, int) or isinstance(idx, np.integer): return self.points[idx], self.colors[idx] return list(zip(self.points[idx], self.colors[idx])) - streamlines = LazyStreamlines(points, scalars, properties, getitem_without_properties) + streamlines = LazyTractogram(points, scalars, properties, getitem_without_properties) points, scalars = streamlines[0] assert_array_equal(points, self.points[0]) assert_array_equal(scalars, self.colors[0]) @@ -270,12 +475,12 @@ def test_lazy_streamlines_len(self): warnings.simplefilter("always") # Always trigger warnings. # Calling `len` will create new generators each time. - streamlines = LazyStreamlines(points, scalars, properties) + streamlines = LazyTractogram(points, scalars, properties) # This should produce a warning message. assert_equal(len(streamlines), self.nb_streamlines) assert_equal(len(w), 1) - streamlines = LazyStreamlines(points, scalars, properties) + streamlines = LazyTractogram(points, scalars, properties) # This should still produce a warning message. assert_equal(len(streamlines), self.nb_streamlines) assert_equal(len(w), 2) @@ -287,7 +492,7 @@ def test_lazy_streamlines_len(self): with clear_and_catch_warnings(record=True, modules=[base_format]) as w: # Once we iterated through the streamlines, we know the length. 
- streamlines = LazyStreamlines(points, scalars, properties) + streamlines = LazyTractogram(points, scalars, properties) assert_true(streamlines.header.nb_streamlines is None) for streamline in streamlines: pass @@ -299,15 +504,15 @@ def test_lazy_streamlines_len(self): with clear_and_catch_warnings(record=True, modules=[base_format]) as w: # It first checks if number of streamlines is in the header. - streamlines = LazyStreamlines(points, scalars, properties) + streamlines = LazyTractogram(points, scalars, properties) streamlines.header.nb_streamlines = 1234 # This should *not* produce a warning. assert_equal(len(streamlines), 1234) assert_equal(len(w), 0) def test_lazy_streamlines_header(self): - # Empty `LazyStreamlines`, with default header - streamlines = LazyStreamlines() + # Empty `LazyTractogram`, with default header + streamlines = LazyTractogram() assert_true(streamlines.header.nb_streamlines is None) assert_equal(streamlines.header.nb_scalars_per_point, 0) assert_equal(streamlines.header.nb_properties_per_streamline, 0) @@ -318,7 +523,7 @@ def test_lazy_streamlines_header(self): points = lambda: (x for x in self.points) scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - streamlines = LazyStreamlines(points) + streamlines = LazyTractogram(points) header = streamlines.header assert_equal(header.nb_scalars_per_point, 0) diff --git a/nibabel/streamlines/tests/test_header.py b/nibabel/streamlines/tests/test_header.py index a36a257818..398195f615 100644 --- a/nibabel/streamlines/tests/test_header.py +++ b/nibabel/streamlines/tests/test_header.py @@ -3,11 +3,11 @@ from nose.tools import assert_equal, assert_true from numpy.testing import assert_array_equal -from nibabel.streamlines.header import StreamlinesHeader +from nibabel.streamlines.header import TractogramHeader def test_streamlines_header(): - header = StreamlinesHeader() + header = TractogramHeader() assert_true(header.nb_streamlines is None) assert_true(header.nb_scalars_per_point is None) assert_true(header.nb_properties_per_streamline is None) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index e6ab8cab17..1e94b79911 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -12,7 +12,7 @@ from nibabel.testing import assert_arrays_equal from nose.tools import assert_equal, assert_raises, assert_true, assert_false -from ..base_format import Streamlines, LazyStreamlines +from ..base_format import Tractogram, LazyTractogram from ..base_format import HeaderError, UsageWarning from ..header import Field from .. 
import trk

@@ -147,7 +147,7 @@ def test_load_empty_file(self):
             streamlines = nib.streamlines.load(empty_filename,
                                                ref=self.to_world_space,
                                                lazy_load=False)
-            assert_true(type(streamlines), Streamlines)
+            assert_true(isinstance(streamlines, Tractogram))
             check_streamlines(streamlines, 0, [], [], [])
 
     def test_load_simple_file(self):
@@ -155,7 +155,7 @@ def test_load_simple_file(self):
             streamlines = nib.streamlines.load(simple_filename,
                                                ref=self.to_world_space,
                                                lazy_load=False)
-            assert_true(type(streamlines), Streamlines)
+            assert_true(isinstance(streamlines, Tractogram))
             check_streamlines(streamlines, self.nb_streamlines,
                               self.points, [], [])
 
@@ -163,7 +163,7 @@ def test_load_simple_file(self):
             streamlines = nib.streamlines.load(simple_filename,
                                                ref=self.to_world_space,
                                                lazy_load=True)
-            assert_true(type(streamlines), LazyStreamlines)
+            assert_true(isinstance(streamlines, LazyTractogram))
             check_streamlines(streamlines, self.nb_streamlines,
                               self.points, [], [])
 
@@ -182,7 +182,7 @@ def test_load_complex_file(self):
             streamlines = nib.streamlines.load(complex_filename,
                                                ref=self.to_world_space,
                                                lazy_load=False)
-            assert_true(type(streamlines), Streamlines)
+            assert_true(isinstance(streamlines, Tractogram))
             check_streamlines(streamlines, self.nb_streamlines,
                               self.points, scalars, properties)
 
@@ -190,12 +190,12 @@ def test_load_complex_file(self):
             streamlines = nib.streamlines.load(complex_filename,
                                                ref=self.to_world_space,
                                                lazy_load=True)
-            assert_true(type(streamlines), LazyStreamlines)
+            assert_true(isinstance(streamlines, LazyTractogram))
             check_streamlines(streamlines, self.nb_streamlines,
                               self.points, scalars, properties)
 
     def test_save_simple_file(self):
-        streamlines = Streamlines(self.points)
+        streamlines = Tractogram(self.points)
         for ext in nib.streamlines.FORMATS.keys():
             with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f:
                 nib.streamlines.save(streamlines, f.name)
@@ -204,7 +204,7 @@ def test_save_simple_file(self):
                                   self.points, [], [])
 
     def test_save_complex_file(self):
-        streamlines = Streamlines(self.points, scalars=self.colors, properties=self.mean_curvature_torsion)
+        streamlines = Tractogram(self.points, scalars=self.colors, properties=self.mean_curvature_torsion)
         for ext, cls in nib.streamlines.FORMATS.items():
             with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f:
                 with clear_and_catch_warnings(record=True, modules=[trk]) as w:
diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py
index e29b288dc5..6013fe5507 100644
--- a/nibabel/streamlines/tests/test_trk.py
+++ b/nibabel/streamlines/tests/test_trk.py
@@ -9,7 +9,7 @@
 from nose.tools import assert_equal, assert_raises, assert_true
 
 from .. import base_format
-from ..base_format import Streamlines, LazyStreamlines
+from ..base_format import Tractogram, LazyTractogram
 from ..base_format import DataError, HeaderError, HeaderWarning, UsageWarning
 from ..
import trk @@ -128,7 +128,7 @@ def test_load_file_with_wrong_information(self): assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) def test_write_simple_file(self): - streamlines = Streamlines(self.points) + streamlines = Tractogram(self.points) trk_file = BytesIO() TrkFile.save(streamlines, trk_file) @@ -146,7 +146,7 @@ def test_write_simple_file(self): def test_write_complex_file(self): # With scalars - streamlines = Streamlines(self.points, scalars=self.colors) + streamlines = Tractogram(self.points, scalars=self.colors) trk_file = BytesIO() TrkFile.save(streamlines, trk_file) @@ -158,7 +158,7 @@ def test_write_complex_file(self): self.points, self.colors, []) # With properties - streamlines = Streamlines(self.points, properties=self.mean_curvature_torsion) + streamlines = Tractogram(self.points, properties=self.mean_curvature_torsion) trk_file = BytesIO() TrkFile.save(streamlines, trk_file) @@ -169,7 +169,7 @@ def test_write_complex_file(self): self.points, [], self.mean_curvature_torsion) # With scalars and properties - streamlines = Streamlines(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) + streamlines = Tractogram(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) trk_file = BytesIO() TrkFile.save(streamlines, trk_file) @@ -191,14 +191,14 @@ def test_write_erroneous_file(self): [(0, 1, 0)], [(0, 0, 1)]] - streamlines = Streamlines(self.points, scalars) + streamlines = Tractogram(self.points, scalars) assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) # No scalars for every streamlines scalars = [[(1, 0, 0)]*1, [(0, 1, 0)]*2] - streamlines = Streamlines(self.points, scalars) + streamlines = Tractogram(self.points, scalars) assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) # # Unit test moved to test_base_format.py @@ -207,7 +207,7 @@ def test_write_erroneous_file(self): # [(0, 1, 0), (0, 1)], # [(0, 0, 1)]*5] - #streamlines = Streamlines(self.points, scalars) + #streamlines = Tractogram(self.points, scalars) #assert_raises(ValueError, TrkFile.save, streamlines, BytesIO()) # # Unit test moved to test_base_format.py @@ -216,7 +216,7 @@ def test_write_erroneous_file(self): # [(0, 1)]*2, # [(0, 0, 1)]*5] - # streamlines = Streamlines(self.points, scalars) + # streamlines = Tractogram(self.points, scalars) # assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) # # Unit test moved to test_base_format.py @@ -224,14 +224,14 @@ def test_write_erroneous_file(self): # properties = [np.array([1.11, 1.22], dtype="f4"), # np.array([2.11], dtype="f4"), # np.array([3.11, 3.22], dtype="f4")] - # streamlines = Streamlines(self.points, properties=properties) + # streamlines = Tractogram(self.points, properties=properties) # assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) # # Unit test moved to test_base_format.py # No properties for every streamlines properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11, 2.22], dtype="f4")] - streamlines = Streamlines(self.points, properties=properties) + streamlines = Tractogram(self.points, properties=properties) assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) def test_write_file_lazy_streamlines(self): @@ -239,7 +239,7 @@ def test_write_file_lazy_streamlines(self): scalars = lambda: (scalar for scalar in self.colors) properties = lambda: (prop for prop in self.mean_curvature_torsion) - streamlines = LazyStreamlines(points, scalars, properties) + streamlines = LazyTractogram(points, scalars, properties) # No need to manually set 
`nb_streamlines` in the header since we count # them as writing. #streamlines.header.nb_streamlines = self.nb_streamlines diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 70df0385d0..53087ddad5 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -13,9 +13,9 @@ from nibabel.volumeutils import (native_code, swapped_code) from nibabel.streamlines.base_format import CompactList -from nibabel.streamlines.base_format import StreamlinesFile +from nibabel.streamlines.base_format import TractogramFile from nibabel.streamlines.base_format import DataError, HeaderError, HeaderWarning -from nibabel.streamlines.base_format import Streamlines, LazyStreamlines +from nibabel.streamlines.base_format import Tractogram, LazyTractogram from nibabel.streamlines.header import Field from nibabel.streamlines.utils import get_affine_from_reference @@ -140,15 +140,15 @@ def __init__(self, fileobj): # Keep the file position where the data begin. self.offset_data = f.tell() - if f.name is not None and self.header[Field.NB_STREAMLINES] > 0: - filesize = os.path.getsize(f.name) - self.offset_data - # Remove properties - filesize -= self.header[Field.NB_STREAMLINES] * self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * 4. - # Remove the points count at the beginning of each streamline. - filesize -= self.header[Field.NB_STREAMLINES] * 4. - # Get nb points. - nb_points = filesize / ((3 + self.header[Field.NB_SCALARS_PER_POINT]) * 4.) - self.header[Field.NB_POINTS] = int(nb_points) + # if f.name is not None and self.header[Field.NB_STREAMLINES] > 0: + # filesize = os.path.getsize(f.name) - self.offset_data + # # Remove properties + # filesize -= self.header[Field.NB_STREAMLINES] * self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * 4. + # # Remove the points count at the beginning of each streamline. + # filesize -= self.header[Field.NB_STREAMLINES] * 4. + # # Get nb points. + # nb_points = filesize / ((3 + self.header[Field.NB_SCALARS_PER_POINT]) * 4.) + # self.header[Field.NB_POINTS] = int(nb_points) def __iter__(self): i4_dtype = np.dtype(self.endianness + "i4") @@ -162,10 +162,10 @@ def __iter__(self): nb_pts_and_scalars = int(3 + self.header[Field.NB_SCALARS_PER_POINT]) pts_and_scalars_size = int(nb_pts_and_scalars * f4_dtype.itemsize) - slice_pts_and_scalars = lambda data: (data, []) - if self.header[Field.NB_SCALARS_PER_POINT] > 0: - # This is faster than `np.split` - slice_pts_and_scalars = lambda data: (data[:, :3], data[:, 3:]) + #slice_pts_and_scalars = lambda data: (data, [[]]) + #if self.header[Field.NB_SCALARS_PER_POINT] > 0: + # This is faster than `np.split` + slice_pts_and_scalars = lambda data: (data[:, :3], data[:, 3:]) # Using np.fromfile would be faster, but does not support StringIO read_pts_and_scalars = lambda nb_pts: slice_pts_and_scalars(np.ndarray(shape=(nb_pts, nb_pts_and_scalars), @@ -202,14 +202,6 @@ def __iter__(self): pts, scalars = read_pts_and_scalars(nb_pts) properties = read_properties() - # TRK's streamlines are in 'voxelmm' space, we send them to voxel space. - pts = pts / self.header[Field.VOXEL_SIZES] - # TrackVis considers coordinate (0,0,0) to be the corner of the - # voxel whereas streamlines returned assume (0,0,0) to be the - # center of the voxel. Thus, streamlines are shifted of half - #a voxel. - pts -= np.array(self.header[Field.VOXEL_SIZES])/2. 
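For reference, the two steps this removed reader code applied to every streamline are plain NumPy arithmetic. The values below are made up, and `vs` merely stands in for `header[Field.VOXEL_SIZES]`:

    import numpy as np

    vs = np.array([2.0, 2.0, 2.0], dtype=np.float32)             # hypothetical voxel sizes (mm)
    pts_voxelmm = np.array([[1.0, 1.0, 1.0]], dtype=np.float32)  # one made-up point

    # Same two steps as the reader code: scale out the voxel size,
    # then the half-voxel shift for the corner-versus-center convention.
    pts_voxel = pts_voxelmm / vs
    pts_voxel -= vs / 2.0
    print(pts_voxel)  # -> [[-0.5 -0.5 -0.5]]

The next hunks fold exactly this per-point arithmetic into a single affine applied at load time.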
-
-                yield pts, scalars, properties
                 i += 1
 
@@ -317,7 +309,7 @@ def write(self, streamlines):
         self.file.write(self.header[0].tostring())
 
 
-class TrkFile(StreamlinesFile):
+class TrkFile(TractogramFile):
     ''' Convenience class to encapsulate TRK file format.
 
     Note
@@ -393,80 +385,95 @@ def load(fileobj, ref=None, lazy_load=False):
 
         Returns
         -------
-        streamlines : Streamlines object
+        streamlines : Tractogram object
             Returns an object containing streamlines' data and header
-            information. See `nibabel.Streamlines`.
+            information. See `nibabel.Tractogram`.
 
         Notes
         -----
-        Streamlines are assumed to be in voxel space where coordinate (0,0,0)
+        Tractograms are assumed to be in voxel space where coordinate (0,0,0)
         refers to the center of the voxel.
         '''
         trk_reader = TrkReader(fileobj)
 
-        # Check if reference space matches one from TRK's header.
+        # TRK's streamlines are in 'voxelmm' space; we send them to rasmm.
         affine = trk_reader.header[Field.to_world_space]
-        if ref is not None:
-            affine = get_affine_from_reference(ref)
-            if not np.allclose(affine, trk_reader.header[Field.to_world_space]):
-                raise ValueError("Reference space provided does not match the "
-                                 " one from the TRK file header. Use `ref=None`"
-                                 " to use one contained in the TRK file")
+        affine[range(3), range(3)] /= trk_reader.header[Field.VOXEL_SIZES]
 
-        #points = lambda: (x[0] for x in trk_reader)
-        #scalars = lambda: (x[1] for x in trk_reader)
-        #properties = lambda: (x[2] for x in trk_reader)
-        data = lambda: iter(trk_reader)
+        # TrackVis considers coordinate (0,0,0) to be the corner of the
+        # voxel whereas streamlines returned assume (0,0,0) to be the
+        # center of the voxel. Thus, streamlines are shifted by half
+        # a voxel.
+        affine[:-1, -1] -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2.
 
         if lazy_load:
-            streamlines = LazyStreamlines.create_from_data(data)
+            def _apply_transform(trk_reader):
+                for pts, scals, props in trk_reader:
+                    # TRK's streamlines are in 'voxelmm' space; we send them
+                    # to voxel space.
+                    pts = pts / trk_reader.header[Field.VOXEL_SIZES]
+                    # TrackVis considers coordinate (0,0,0) to be the corner
+                    # of the voxel whereas streamlines returned assume (0,0,0)
+                    # to be the center of the voxel. Thus, streamlines are
+                    # shifted by half a voxel.
+                    pts -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2.
+                    yield pts, scals, props
+
+            data = lambda: _apply_transform(trk_reader)
+            streamlines = LazyTractogram.create_from_data(data)
 
             # Overwrite scalars and properties if there is none
             if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0:
                 streamlines.scalars = lambda: []
             if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0:
                 streamlines.properties = lambda: []
-        elif Field.NB_POINTS in trk_reader.header:
-            # 'count' field is provided, we can avoid creating list of numpy
-            # arrays (more memory efficient).
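Once folded into a single 4x4 affine as in the new `load()` above, the whole transform can be applied in batch with nibabel's `apply_affine`. A self-contained check with invented values:

    import numpy as np
    from nibabel.affines import apply_affine

    # Made-up affine: scale by 0.5, then translate by -1 on each axis.
    affine = np.diag([0.5, 0.5, 0.5, 1.0])
    affine[:3, 3] = -1.0

    pts = np.array([[2.0, 2.0, 2.0],
                    [4.0, 4.0, 4.0]])
    print(apply_affine(affine, pts))  # -> [[0. 0. 0.] [1. 1. 1.]]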
-
-            nb_streamlines = trk_reader.header[Field.NB_STREAMLINES]
-            nb_points = trk_reader.header[Field.NB_POINTS]
-
-            points = CompactList(preallocate=(nb_points, 3))
-            scalars = CompactList(preallocate=(nb_points, trk_reader.header[Field.NB_SCALARS_PER_POINT]))
-            properties = np.empty((nb_streamlines, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]),
-                                  dtype=np.float32)
-
-            offset = 0
-            offsets = []
-            lengths = []
-            for i, (pts, scals, props) in enumerate(data()):
-                offsets.append(offset)
-                lengths.append(len(pts))
-                try:
-                    points.data[offset:offset+len(pts)] = pts
-                except:
-                    from ipdb import set_trace as dbg
-                    dbg()
-                scalars.data[offset:offset+len(scals)] = scals
-                properties[i] = props
-                offset += len(pts)
-
-            points.offsets = offsets
-            scalars.offsets = offsets
-            points.lengths = lengths
-            scalars.lengths = lengths
-            streamlines = Streamlines(points, scalars, properties)
-            # Overwrite scalars and properties if there is none
-            if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0:
-                streamlines.scalars = []
-            if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0:
-                streamlines.properties = []
+        # elif Field.NB_POINTS in trk_reader.header:
+        #     # 'count' field is provided, we can avoid creating list of numpy
+        #     # arrays (more memory efficient).
+
+        #     nb_streamlines = trk_reader.header[Field.NB_STREAMLINES]
+        #     nb_points = trk_reader.header[Field.NB_POINTS]
+
+        #     points = CompactList()
+        #     points._data = np.empty((nb_points, 3), dtype=np.float32)
+
+        #     scalars = CompactList()
+        #     scalars._data = np.empty((nb_points, trk_reader.header[Field.NB_SCALARS_PER_POINT]),
+        #                              dtype=np.float32)
+
+        #     properties = CompactList()
+        #     properties._data = np.empty((nb_streamlines, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]),
+        #                                 dtype=np.float32)
+
+        #     offset = 0
+        #     offsets = []
+        #     lengths = []
+        #     for i, (pts, scals, props) in enumerate(trk_reader):
+        #         offsets.append(offset)
+        #         lengths.append(len(pts))
+        #         points._data[offset:offset+len(pts)] = pts
+        #         scalars._data[offset:offset+len(scals)] = scals
+        #         properties._data[i] = props
+        #         offset += len(pts)
+
+        #     points.offsets = offsets
+        #     scalars.offsets = offsets
+        #     points.lengths = lengths
+        #     scalars.lengths = lengths
+
+        #     streamlines = Tractogram(points, scalars, properties)
+        #     streamlines.apply_affine(affine)
+
+        #     # Overwrite scalars and properties if there is none
+        #     if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0:
+        #         streamlines.scalars = []
+        #     if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0:
+        #         streamlines.properties = []
         else:
-            streamlines = Streamlines(*zip(*data()))
+            streamlines = Tractogram.create_from_generator(trk_reader)
+            #streamlines = Tractogram(*zip(*trk_reader))
+            streamlines.apply_affine(affine)
 
             # Overwrite scalars and properties if there is none
             if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0:
@@ -502,9 +510,9 @@ def save(streamlines, fileobj, ref=None):
 
         Parameters
         ----------
-        streamlines : Streamlines object
+        streamlines : Tractogram object
             Object containing streamlines' data and header information.
-            See 'nibabel.Streamlines'.
+            See 'nibabel.Tractogram'.
 
         fileobj : string or file-like object
            If string, a filename; otherwise an open file-like object
@@ -516,7 +524,7 @@
 
         Notes
         -----
-        Streamlines are assumed to be in voxel space where coordinate (0,0,0)
+        Tractograms are assumed to be in voxel space where coordinate (0,0,0)
         refers to the center of the voxel.
''' if ref is not None: From 177d01a067449c3a7d89e7d16468f5605cc20e10 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 30 Oct 2015 23:21:38 -0400 Subject: [PATCH 012/135] Moves some unit tests --- nibabel/streamlines/base_format.py | 11 ++++++++++ nibabel/streamlines/tests/test_base_format.py | 16 ++++++++++++++- nibabel/streamlines/tests/test_trk.py | 20 +++++++++---------- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 95fc2669b5..51ccf31f68 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -276,6 +276,9 @@ def create_from_generator(cls, gen): scals = np.asarray(first_element[1]) props = np.asarray(first_element[2]) + scals_shape = scals.shape + props_shape = props.shape + points._data = np.empty((BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype) scalars._data = np.empty((BUFFER_SIZE, scals.shape[1]), dtype=scals.dtype) properties = np.empty((BUFFER_SIZE, props.shape[0]), dtype=props.dtype) @@ -286,6 +289,14 @@ def create_from_generator(cls, gen): scals = np.asarray(scals) props = np.asarray(props) + if scals.shape[1] != scals_shape[1]: + raise ValueError("Number of scalars differs from one" + " point or streamline to another") + + if props.shape != props_shape: + raise ValueError("Number of properties differs from one" + " streamline to another") + end = offset + len(pts) if end >= len(points._data): # Resize is needed (at least `len(pts)` items will be added). diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index b699066c1c..23dbd425b3 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -289,7 +289,20 @@ def test_streamlines_creation_from_arrays(self): for streamline in streamlines: pass - #streamlines = Tractogram(self.points, scalars) + # Inconsistent number of scalars between points + scalars = [[(1, 0, 0)]*1, + [(0, 1, 0), (0, 1)], + [(0, 0, 1)]*5] + + assert_raises(ValueError, Tractogram, self.points, scalars) + + # Unit test moved to test_base_format.py + # Inconsistent number of scalars between streamlines + scalars = [[(1, 0, 0)]*1, + [(0, 1)]*2, + [(0, 0, 1)]*5] + + assert_raises(ValueError, Tractogram, self.points, scalars) def test_streamlines_getter(self): # Tractogram with only points @@ -356,6 +369,7 @@ def test_header(self): repr(streamlines.header) + class TestLazyTractogram(unittest.TestCase): def setUp(self): diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 6013fe5507..0c5d58aadf 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -207,8 +207,8 @@ def test_write_erroneous_file(self): # [(0, 1, 0), (0, 1)], # [(0, 0, 1)]*5] - #streamlines = Tractogram(self.points, scalars) - #assert_raises(ValueError, TrkFile.save, streamlines, BytesIO()) + # streamlines = Tractogram(self.points, scalars) + # assert_raises(ValueError, TrkFile.save, streamlines, BytesIO()) # # Unit test moved to test_base_format.py # # Inconsistent number of scalars between streamlines @@ -219,15 +219,15 @@ def test_write_erroneous_file(self): # streamlines = Tractogram(self.points, scalars) # assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) - # # Unit test moved to test_base_format.py - # # Inconsistent number of properties - # properties = [np.array([1.11, 1.22], dtype="f4"), - # np.array([2.11], dtype="f4"), - 
# np.array([3.11, 3.22], dtype="f4")] - # streamlines = Tractogram(self.points, properties=properties) - # assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + # Unit test moved to test_base_format.py + # Inconsistent number of properties + properties = [np.array([1.11, 1.22], dtype="f4"), + np.array([2.11], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + streamlines = Tractogram(self.points, properties=properties) + assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) - # # Unit test moved to test_base_format.py + # Unit test moved to test_base_format.py # No properties for every streamlines properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11, 2.22], dtype="f4")] From ee58ae0a79efb79ab3e9835e3a98aa05ea5d819b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 2 Nov 2015 08:30:39 -0500 Subject: [PATCH 013/135] Reduced memory usage. --- nibabel/streamlines/base_format.py | 22 ++++++++---- nibabel/streamlines/tests/test_base_format.py | 18 +++++++++- nibabel/streamlines/trk.py | 36 ++++++++----------- 3 files changed, 47 insertions(+), 29 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 51ccf31f68..060395f49b 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -258,7 +258,7 @@ def __init__(self, points=None, scalars=None, properties=None): @classmethod def create_from_generator(cls, gen): - BUFFER_SIZE = 1000 + BUFFER_SIZE = 1000000 points = CompactList() scalars = CompactList() @@ -301,7 +301,7 @@ def create_from_generator(cls, gen): if end >= len(points._data): # Resize is needed (at least `len(pts)` items will be added). points._data.resize((len(points._data) + len(pts)+BUFFER_SIZE, pts.shape[1])) - scalars._data.resize((len(scalars._data) + len(scalars)+BUFFER_SIZE, scals.shape[1])) + scalars._data.resize((len(scalars._data) + len(scals)+BUFFER_SIZE, scals.shape[1])) points._offsets.append(offset) points._lengths.append(len(pts)) @@ -315,14 +315,24 @@ def create_from_generator(cls, gen): properties[i] = props + # Clear unused memory. + points._data.resize((offset, pts.shape[1])) + + if scals_shape[1] == 0: + # Because resizing an empty ndarray creates memory! + scalars._data = np.empty((offset, scals.shape[1])) + else: + scalars._data.resize((offset, scals.shape[1])) + # Share offsets and lengths between points and scalars. scalars._offsets = points._offsets scalars._lengths = points._lengths - # Clear unused memory. - points._data.resize((offset, pts.shape[1])) - scalars._data.resize((offset, scals.shape[1])) - properties.resize((i+1, props.shape[0])) + if props_shape[0] == 0: + # Because resizing an empty ndarray creates memory! 
+ properties = np.empty((i+1, props.shape[0])) + else: + properties.resize((i+1, props.shape[0])) return cls(points, scalars, properties) diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index 23dbd425b3..1aa650489e 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -330,13 +330,29 @@ def test_streamlines_getter(self): assert_arrays_equal(r_streamlines.scalars, self.colors[::-1]) assert_arrays_equal(r_streamlines.properties, self.mean_curvature_torsion[::-1]) + def test_streamlines_creation_from_generator(self): + # Create `Tractogram` from a generator yielding 3-tuples + gen = (x for x in zip(self.points, self.colors, self.mean_curvature_torsion)) + + streamlines = Tractogram.create_from_generator(gen) + with suppress_warnings(): + assert_equal(len(streamlines), self.nb_streamlines) + + assert_arrays_equal(streamlines.points, self.points) + assert_arrays_equal(streamlines.scalars, self.colors) + assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + + # Check if we can iterate through the streamlines. + for streamline in streamlines: + pass + def test_streamlines_creation_from_coroutines(self): # Points, scalars and properties points = lambda: (x for x in self.points) scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - # To create streamlines from coroutines use `LazyTractogram`. + # To create streamlines from multiple coroutines use `LazyTractogram`. assert_raises(TypeError, Tractogram, points, scalars, properties) def test_header(self): diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 53087ddad5..3194218190 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -154,30 +154,12 @@ def __iter__(self): i4_dtype = np.dtype(self.endianness + "i4") f4_dtype = np.dtype(self.endianness + "f4") - #from io import BufferedReader with Opener(self.fileobj) as f: - #f = BufferedReader(f.fobj) start_position = f.tell() nb_pts_and_scalars = int(3 + self.header[Field.NB_SCALARS_PER_POINT]) pts_and_scalars_size = int(nb_pts_and_scalars * f4_dtype.itemsize) - - #slice_pts_and_scalars = lambda data: (data, [[]]) - #if self.header[Field.NB_SCALARS_PER_POINT] > 0: - # This is faster than `np.split` - slice_pts_and_scalars = lambda data: (data[:, :3], data[:, 3:]) - - # Using np.fromfile would be faster, but does not support StringIO - read_pts_and_scalars = lambda nb_pts: slice_pts_and_scalars(np.ndarray(shape=(nb_pts, nb_pts_and_scalars), - dtype=f4_dtype, - buffer=f.read(nb_pts * pts_and_scalars_size))) - properties_size = int(self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize) - read_properties = lambda: [] - if self.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: - read_properties = lambda: np.fromstring(f.read(properties_size), - dtype=f4_dtype, - count=self.header[Field.NB_PROPERTIES_PER_STREAMLINE]) # Set the file position at the beginning of the data. f.seek(self.offset_data, os.SEEK_SET) @@ -188,6 +170,7 @@ def __iter__(self): nb_streamlines = np.inf i = 0 + nb_pts_dtype = i4_dtype.str[:-1] while i < nb_streamlines: nb_pts_str = f.read(i4_dtype.itemsize) @@ -196,13 +179,22 @@ def __iter__(self): break # Read number of points of the next streamline. 
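That count-prefixed record layout — an i4 number of points followed by fixed-width f4 rows — can be decoded in isolation the same way this loop does it. A minimal sketch with hand-packed little-endian bytes:

    import struct
    import numpy as np
    from io import BytesIO

    # Two points: an i4 count followed by 2 * 3 f4 coordinates.
    raw = struct.pack("<i", 2) + struct.pack("<6f", 0, 1, 2, 3, 4, 5)
    f = BytesIO(raw)

    nb_pts = struct.unpack("<i", f.read(4))[0]
    pts = np.ndarray(shape=(nb_pts, 3), dtype="<f4",
                     buffer=f.read(nb_pts * 3 * 4))
    print(pts)  # -> [[0. 1. 2.] [3. 4. 5.]]

Building the ndarray directly over the read buffer avoids an extra copy, which is the point of the restructuring in this hunk.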
- nb_pts = struct.unpack(i4_dtype.str[:-1], nb_pts_str)[0] + nb_pts = struct.unpack(nb_pts_dtype, nb_pts_str)[0] # Read streamline's data - pts, scalars = read_pts_and_scalars(nb_pts) - properties = read_properties() + points_and_scalars = np.ndarray(shape=(nb_pts, nb_pts_and_scalars), + dtype=f4_dtype, + buffer=f.read(nb_pts * pts_and_scalars_size)) + + points = points_and_scalars[:, :3] + scalars = points_and_scalars[:, 3:] + + # Read properties + properties = np.ndarray(shape=(self.header[Field.NB_PROPERTIES_PER_STREAMLINE],), + dtype=f4_dtype, + buffer=f.read(properties_size)) - yield pts, scalars, properties + yield points, scalars, properties i += 1 # In case the 'count' field was not provided. From 54a561997c9bba5ef1689dd50c245aaf2535b421 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 2 Nov 2015 11:04:05 -0500 Subject: [PATCH 014/135] Refactored streamlines->tractogram and points->streamlines --- nibabel/streamlines/base_format.py | 220 +++++------ nibabel/streamlines/tests/test_base_format.py | 353 +++++++++--------- nibabel/streamlines/tests/test_streamlines.py | 134 +++---- nibabel/streamlines/tests/test_trk.py | 152 ++++---- nibabel/streamlines/trk.py | 62 +-- 5 files changed, 450 insertions(+), 471 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 060395f49b..42e34d0a37 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -2,6 +2,8 @@ import numpy as np from warnings import warn +from abc import ABCMeta, abstractmethod, abstractproperty + from nibabel.externals.six.moves import zip_longest from nibabel.affines import apply_affine @@ -44,7 +46,7 @@ def __init__(self, iterable=None): if iterable is not None: # Initialize the `CompactList` object from iterable's item. - BUFFER_SIZE = 1000 + BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. offset = 0 for i, e in enumerate(iterable): @@ -190,15 +192,16 @@ def __len__(self): def __repr__(self): return repr(list(self)) + class TractogramItem(object): ''' Class containing information about one streamline. - ``TractogramItem`` objects have three main properties: `points`, `scalars` + ``TractogramItem`` objects have three main properties: `streamline`, `scalars` and ``properties``. Parameters ---------- - points : ndarray of shape (N, 3) + streamline : ndarray of shape (N, 3) Points of this streamline represented as an ndarray of shape (N, 3) where N is the number of points. @@ -211,56 +214,56 @@ class TractogramItem(object): Properties associated with this streamline and represented as an ndarray of shape (P,) where P is the number of properties. ''' - def __init__(self, points, scalars=None, properties=None): - #if scalars is not None and len(points) != len(scalars): - # raise ValueError("First dimension of points and scalars must match.") + def __init__(self, streamline, scalars=None, properties=None): + #if scalars is not None and len(streamline) != len(scalars): + # raise ValueError("First dimension of streamline and scalars must match.") - self.points = np.asarray(points) + self.streamline = np.asarray(streamline) self.scalars = np.asarray([] if scalars is None else scalars) self.properties = np.asarray([] if properties is None else properties) def __iter__(self): - return iter(self.points) + return iter(self.streamline) def __len__(self): - return len(self.points) + return len(self.streamline) class Tractogram(object): ''' Class containing information about streamlines. 
-    Tractogram objects have three main properties: ``points``, ``scalars``
+    Tractogram objects have three main properties: ``streamlines``, ``scalars``
     and ``properties``. Tractogram objects can be iterated over, producing a
-    tuple of ``points``, ``scalars`` and ``properties`` for each streamline.
+    tuple of ``streamlines``, ``scalars`` and ``properties`` for each streamline.
 
     Parameters
     ----------
-    points : list of ndarray of shape (N, 3)
-        Sequence of T streamlines. One streamline is an ndarray of shape (N, 3)
-        where N is the number of points in a streamline.
-
-    scalars : list of ndarray of shape (N, M)
-        Sequence of T ndarrays of shape (N, M) where T is the number of
-        streamlines defined by ``points``, N is the number of points
-        for a particular streamline and M is the number of scalars
+    streamlines : list of ndarray of shape (Nt, 3)
+        Sequence of T streamlines. One streamline is an ndarray of shape (Nt, 3)
+        where Nt is the number of points of streamline t.
+
+    scalars : list of ndarray of shape (Nt, M)
+        Sequence of T ndarrays of shape (Nt, M) where T is the number of
+        streamlines defined by ``streamlines``, Nt is the number of points
+        for a particular streamline t and M is the number of scalars
         associated to each point (excluding the three coordinates).
 
     properties : list of ndarray of shape (P,)
         Sequence of T ndarrays of shape (P,) where T is the number of
-        streamlines defined by ``points``, P is the number of properties
+        streamlines defined by ``streamlines``, P is the number of properties
         associated to each streamline.
     '''
-    def __init__(self, points=None, scalars=None, properties=None):
+    def __init__(self, streamlines=None, scalars=None, properties=None):
         self._header = TractogramHeader()
-        self.points = points
+        self.streamlines = streamlines
         self.scalars = scalars
         self.properties = properties
 
     @classmethod
     def create_from_generator(cls, gen):
-        BUFFER_SIZE = 1000000
+        BUFFER_SIZE = 10000000  # About 128 Mb if item shape is 3.
 
-        points = CompactList()
+        streamlines = CompactList()
         scalars = CompactList()
         properties = np.array([])
 
@@ -269,7 +272,7 @@ def create_from_generator(cls, gen):
             first_element = next(gen)
             gen = itertools.chain([first_element], gen)
         except StopIteration:
-            return cls(points, scalars, properties)
+            return cls(streamlines, scalars, properties)
 
         # Allocate some buffer memory.
         pts = np.asarray(first_element[0])
@@ -279,7 +282,7 @@ def create_from_generator(cls, gen):
         scals_shape = scals.shape
         props_shape = props.shape
 
-        points._data = np.empty((BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype)
+        streamlines._data = np.empty((BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype)
         scalars._data = np.empty((BUFFER_SIZE, scals.shape[1]), dtype=scals.dtype)
         properties = np.empty((BUFFER_SIZE, props.shape[0]), dtype=props.dtype)
 
@@ -298,14 +301,14 @@ def create_from_generator(cls, gen):
                     " streamline to another")
 
                 end = offset + len(pts)
-                if end >= len(points._data):
+                if end >= len(streamlines._data):
                     # Resize is needed (at least `len(pts)` items will be added).
- points._data.resize((len(points._data) + len(pts)+BUFFER_SIZE, pts.shape[1])) + streamlines._data.resize((len(streamlines._data) + len(pts)+BUFFER_SIZE, pts.shape[1])) scalars._data.resize((len(scalars._data) + len(scals)+BUFFER_SIZE, scals.shape[1])) - points._offsets.append(offset) - points._lengths.append(len(pts)) - points._data[offset:offset+len(pts)] = pts + streamlines._offsets.append(offset) + streamlines._lengths.append(len(pts)) + streamlines._data[offset:offset+len(pts)] = pts scalars._data[offset:offset+len(scals)] = scals offset += len(pts) @@ -316,7 +319,7 @@ def create_from_generator(cls, gen): properties[i] = props # Clear unused memory. - points._data.resize((offset, pts.shape[1])) + streamlines._data.resize((offset, pts.shape[1])) if scals_shape[1] == 0: # Because resizing an empty ndarray creates memory! @@ -324,9 +327,9 @@ def create_from_generator(cls, gen): else: scalars._data.resize((offset, scals.shape[1])) - # Share offsets and lengths between points and scalars. - scalars._offsets = points._offsets - scalars._lengths = points._lengths + # Share offsets and lengths between streamlines and scalars. + scalars._offsets = streamlines._offsets + scalars._lengths = streamlines._lengths if props_shape[0] == 0: # Because resizing an empty ndarray creates memory! @@ -334,7 +337,7 @@ def create_from_generator(cls, gen): else: properties.resize((i+1, props.shape[0])) - return cls(points, scalars, properties) + return cls(streamlines, scalars, properties) @property @@ -342,16 +345,16 @@ def header(self): return self._header @property - def points(self): - return self._points + def streamlines(self): + return self._streamlines - @points.setter - def points(self, value): - self._points = value + @streamlines.setter + def streamlines(self, value): + self._streamlines = value if not isinstance(value, CompactList): - self._points = CompactList(value) + self._streamlines = CompactList(value) - self.header.nb_streamlines = len(self.points) + self.header.nb_streamlines = len(self.streamlines) @property def scalars(self): @@ -382,11 +385,11 @@ def properties(self, value): self.header.nb_properties_per_streamline = len(self.properties[0]) def __iter__(self): - for data in zip_longest(self.points, self.scalars, self.properties, fillvalue=None): + for data in zip_longest(self.streamlines, self.scalars, self.properties, fillvalue=None): yield TractogramItem(*data) def __getitem__(self, idx): - pts = self.points[idx] + pts = self.streamlines[idx] scalars = [] if len(self.scalars) > 0: scalars = self.scalars[idx] @@ -401,11 +404,11 @@ def __getitem__(self, idx): return TractogramItem(pts, scalars, properties) def __len__(self): - return len(self.points) + return len(self.streamlines) def copy(self): """ Returns a copy of this `Tractogram` object. """ - streamlines = Tractogram(self.points.copy(), self.scalars.copy(), self.properties.copy()) + streamlines = Tractogram(self.streamlines.copy(), self.scalars.copy(), self.properties.copy()) streamlines._header = self.header.copy() return streamlines @@ -419,53 +422,53 @@ def apply_affine(self, affine): affine : 2D array (4,4) Transformation that will be applied on each streamline. """ - if len(self.points) == 0: + if len(self.streamlines) == 0: return - BUFFER_SIZE = 10000 - for i in range(0, len(self.points._data), BUFFER_SIZE): - pts = self.points._data[i:i+BUFFER_SIZE] - self.points._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts) + BUFFER_SIZE = 10000000 # About 128 Mb since pts shape is 3. 
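The preallocate, grow-in-large-steps, then trim idiom used by `create_from_generator` (and echoed by the chunked `apply_affine` loop that continues below) reduces to a few lines. An illustrative sketch with invented names, not the module's API:

    import numpy as np

    BUFFER_SIZE = 10000  # smaller than the module's, same idea

    def accumulate(chunks):
        # Preallocate, grow by big steps, trim: amortized O(n) appends.
        data = np.empty((BUFFER_SIZE, 3), dtype=np.float32)
        offset = 0
        for chunk in chunks:
            end = offset + len(chunk)
            if end > len(data):
                data.resize((len(data) + len(chunk) + BUFFER_SIZE, 3),
                            refcheck=False)
            data[offset:end] = chunk
            offset = end
        data.resize((offset, 3), refcheck=False)  # release the unused tail
        return data

    print(accumulate([np.ones((5, 3)), np.zeros((2, 3))]).shape)  # -> (7, 3)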
+        for i in range(0, len(self.streamlines._data), BUFFER_SIZE):
+            pts = self.streamlines._data[i:i+BUFFER_SIZE]
+            self.streamlines._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts)
 
 
 class LazyTractogram(Tractogram):
     ''' Class containing information about streamlines.
 
-    Tractogram objects have four main properties: ``header``, ``points``,
+    Tractogram objects have four main properties: ``header``, ``streamlines``,
     ``scalars`` and ``properties``. Tractogram objects are iterable and
-    produce tuple of ``points``, ``scalars`` and ``properties`` for each
+    produce a tuple of ``streamlines``, ``scalars`` and ``properties`` for each
     streamline.
 
     Parameters
     ----------
-    points_func : coroutine ouputting (N,3) array-like (optional)
-        Function yielding streamlines' points. One streamline's points is
-        an array-like of shape (N,3) where N is the number of points in a
-        streamline.
-
-    scalars_func : coroutine ouputting (N,M) array-like (optional)
-        Function yielding streamlines' scalars. One streamline's scalars is
-        an array-like of shape (N,M) where N is the number of points for a
-        particular streamline and M is the number of scalars associated to
-        each point (excluding the three coordinates).
+    streamlines_func : coroutine outputting (Nt,3) array-like (optional)
+        Function yielding streamlines. One streamline is
+        an ndarray of shape (Nt,3) where Nt is the number of points of
+        streamline t.
+
+    scalars_func : coroutine outputting (Nt,M) array-like (optional)
+        Function yielding scalars for a particular streamline t. The scalars
+        are represented as an ndarray of shape (Nt,M) where Nt is the number
+        of points of that streamline t and M is the number of scalars
+        associated to each point (excluding the three coordinates).
 
     properties_func : coroutine outputting (P,) array-like (optional)
-        Function yielding streamlines' properties. One streamline's properties
-        is an array-like of shape (P,) where P is the number of properties
-        associated to each streamline.
+        Function yielding properties for a particular streamline t. The
+        properties are represented as an ndarray of shape (P,) where P is
+        the number of properties associated to each streamline.
 
     getitem_func : function `idx -> 3-tuples` (optional)
-        Function returning streamlines (one or a list of 3-tuples) given
-        an index or a slice (i.e. the __getitem__ function to use).
+        Function returning a subset of the tractogram given an index or a
+        slice (i.e. the __getitem__ function to use).
 
     Notes
     -----
     If provided, ``scalars`` and ``properties`` must yield the same number of
-    values as ``points``.
+    values as ``streamlines``.
     '''
-    def __init__(self, points_func=lambda:[], scalars_func=lambda: [], properties_func=lambda: [], getitem_func=None):
-        super(LazyTractogram, self).__init__(points_func, scalars_func, properties_func)
-        self._data = lambda: zip_longest(self.points, self.scalars, self.properties, fillvalue=[])
+    def __init__(self, streamlines_func=lambda:[], scalars_func=lambda: [], properties_func=lambda: [], getitem_func=None):
+        super(LazyTractogram, self).__init__(streamlines_func, scalars_func, properties_func)
+        self._data = lambda: zip_longest(self.streamlines, self.scalars, self.properties, fillvalue=[])
         self._getitem = getitem_func
 
     @classmethod
@@ -475,12 +478,11 @@ def create_from_data(cls, data_func):
 
         Parameters
         ----------
        data_func : coroutine outputting tuple (optional)
-            Function yielding 3-tuples, (streamline's points, streamline's
-            scalars, streamline's properties).
A streamline's points is an - array-like of shape (N,3), a streamline's scalars is an array-like - of shape (N,M) and streamline's properties is an array-like of - shape (P,) where N is the number of points for a particular - streamline, M is the number of scalars associated to each point + Function yielding 3-tuples, (streamlines, scalars, properties). + Streamlines are represented as an ndarray of shape (Nt,3), scalars + as an ndarray of shape (Nt,M) and properties as an ndarray of shape + (P,) where Nt is the number of points for a particular + streamline t, M is the number of scalars associated to each point (excluding the three coordinates) and P is the number of properties associated to each streamline. ''' @@ -489,21 +491,21 @@ def create_from_data(cls, data_func): lazy_streamlines = cls() lazy_streamlines._data = data_func - lazy_streamlines.points = lambda: (x[0] for x in data_func()) + lazy_streamlines.streamlines = lambda: (x[0] for x in data_func()) lazy_streamlines.scalars = lambda: (x[1] for x in data_func()) lazy_streamlines.properties = lambda: (x[2] for x in data_func()) return lazy_streamlines @property - def points(self): - return self._points() + def streamlines(self): + return self._streamlines() - @points.setter - def points(self, value): + @streamlines.setter + def streamlines(self, value): if not callable(value): - raise TypeError("`points` must be a coroutine.") + raise TypeError("`streamlines` must be a coroutine.") - self._points = value + self._streamlines = value @property def scalars(self): @@ -564,12 +566,12 @@ def __len__(self): def copy(self): """ Returns a copy of this `LazyTractogram` object. """ - streamlines = LazyTractogram(self._points, self._scalars, self._properties) + streamlines = LazyTractogram(self._streamlines, self._scalars, self._properties) streamlines._header = self.header.copy() return streamlines def transform(self, affine): - """ Applies an affine transformation on the points of each streamline. + """ Applies an affine transformation on the streamlines. Parameters ---------- @@ -583,19 +585,29 @@ def transform(self, affine): """ return super(LazyTractogram, self).transform(affine, lazy=True) - def to_world_space(self): - """ Sends the streamlines back into world space. - Returns - ------- - streamlines : `LazyTractogram` object - Tractogram living in world space. - """ - return super(LazyTractogram, self).to_world_space(lazy=True) +class TractogramFile(object): + ''' Convenience class to encapsulate streamlines file format. ''' + __metaclass__ = ABCMeta + def __init__(self, tractogram): + self.tractogram = tractogram -class TractogramFile: - ''' Convenience class to encapsulate streamlines file format. ''' + @property + def streamlines(self): + return self.tractogram.streamlines + + @property + def scalars(self): + return self.tractogram.scalars + + @property + def properties(self): + return self.tractogram.properties + + @property + def header(self): + return self.tractogram.header @classmethod def get_magic_number(cls): @@ -692,15 +704,3 @@ def pretty_print(streamlines): Header information relevant to the streamlines file format. ''' raise NotImplementedError() - - -# class DynamicTractogramFile(TractogramFile): -# ''' Convenience class to encapsulate streamlines file format -# that supports appending streamlines to an existing file. 
-# ''' - -# def append(self, streamlines): -# raise NotImplementedError() - -# def __iadd__(self, streamlines): -# return self.append(streamlines) diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index 1aa650489e..99d8c8ea90 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -208,28 +208,28 @@ class TestTractogramItem(unittest.TestCase): def test_creating_tractogram_item(self): rng = np.random.RandomState(42) - points = rng.rand(rng.randint(10, 50), 3) - scalars = rng.rand(len(points), 5) + streamline = rng.rand(rng.randint(10, 50), 3) + scalars = rng.rand(len(streamline), 5) properties = rng.rand(42) # Create a streamline with only points - s = TractogramItem(points) - assert_equal(len(s), len(points)) - assert_array_equal(s.scalars, []) - assert_array_equal(s.properties, []) + t = TractogramItem(streamline) + assert_equal(len(t), len(streamline)) + assert_array_equal(t.scalars, []) + assert_array_equal(t.properties, []) # Create a streamline with points, scalars and properties. - s = TractogramItem(points, scalars, properties) - assert_equal(len(s), len(points)) - assert_array_equal(s.points, points) - assert_array_equal(list(s), points) - assert_equal(len(s), len(scalars)) - assert_array_equal(s.scalars, scalars) - assert_array_equal(s.properties, properties) + t = TractogramItem(streamline, scalars, properties) + assert_equal(len(t), len(streamline)) + assert_array_equal(t.streamline, streamline) + assert_array_equal(list(t), streamline) + assert_equal(len(t), len(scalars)) + assert_array_equal(t.scalars, scalars) + assert_array_equal(t.properties, properties) # # Create a streamline with different number of scalars. - # scalars = rng.rand(len(points)+3, 5) - # assert_raises(ValueError, TractogramItem, points, scalars) + # scalars = rng.rand(len(streamline)+3, 5) + # assert_raises(ValueError, TractogramItem, streamline, scalars) class TestTractogram(unittest.TestCase): @@ -239,9 +239,9 @@ def setUp(self): self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") - self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), - np.arange(2*3, dtype="f4").reshape((2, 3)), - np.arange(5*3, dtype="f4").reshape((5, 3))] + self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), np.array([(0, 1, 0)]*2, dtype="f4"), @@ -251,139 +251,138 @@ def setUp(self): np.array([2.11, 2.22], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] - self.nb_streamlines = len(self.points) + self.nb_tractogram = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) - def test_streamlines_creation_from_arrays(self): + def test_tractogram_creation_from_arrays(self): # Empty - streamlines = Tractogram() - assert_equal(len(streamlines), 0) - assert_arrays_equal(streamlines.points, []) - assert_arrays_equal(streamlines.scalars, []) - assert_arrays_equal(streamlines.properties, []) - - # Check if we can iterate through the streamlines. 
- for streamline in streamlines: + tractogram = Tractogram() + assert_equal(len(tractogram), 0) + assert_arrays_equal(tractogram.streamlines, []) + assert_arrays_equal(tractogram.scalars, []) + assert_arrays_equal(tractogram.properties, []) + + # Check if we can iterate through the tractogram. + for streamline in tractogram: pass - # Only points - streamlines = Tractogram(points=self.points) - assert_equal(len(streamlines), len(self.points)) - assert_arrays_equal(streamlines.points, self.points) - assert_arrays_equal(streamlines.scalars, []) - assert_arrays_equal(streamlines.properties, []) + # Only streamlines + tractogram = Tractogram(streamlines=self.streamlines) + assert_equal(len(tractogram), len(self.streamlines)) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.scalars, []) + assert_arrays_equal(tractogram.properties, []) - # Check if we can iterate through the streamlines. - for streamline in streamlines: + # Check if we can iterate through the tractogram. + for streamline in tractogram: pass # Points, scalars and properties - streamlines = Tractogram(self.points, self.colors, self.mean_curvature_torsion) - assert_equal(len(streamlines), len(self.points)) - assert_arrays_equal(streamlines.points, self.points) - assert_arrays_equal(streamlines.scalars, self.colors) - assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) - - # Check if we can iterate through the streamlines. - for streamline in streamlines: + tractogram = Tractogram(self.streamlines, self.colors, self.mean_curvature_torsion) + assert_equal(len(tractogram), len(self.streamlines)) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.scalars, self.colors) + assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) + + # Check if we can iterate through the tractogram. 
+ for streamline in tractogram: pass - # Inconsistent number of scalars between points + # Inconsistent number of scalars between streamlines scalars = [[(1, 0, 0)]*1, [(0, 1, 0), (0, 1)], [(0, 0, 1)]*5] - assert_raises(ValueError, Tractogram, self.points, scalars) + assert_raises(ValueError, Tractogram, self.streamlines, scalars) # Unit test moved to test_base_format.py - # Inconsistent number of scalars between streamlines + # Inconsistent number of scalars between tractogram scalars = [[(1, 0, 0)]*1, [(0, 1)]*2, [(0, 0, 1)]*5] - assert_raises(ValueError, Tractogram, self.points, scalars) + assert_raises(ValueError, Tractogram, self.streamlines, scalars) - def test_streamlines_getter(self): - # Tractogram with only points - streamlines = Tractogram(points=self.points) + def test_tractogram_getter(self): + # Tractogram with only streamlines + tractogram = Tractogram(streamlines=self.streamlines) - selected_streamlines = streamlines[::2] - assert_equal(len(selected_streamlines), (len(self.points)+1)//2) + selected_tractogram = tractogram[::2] + assert_equal(len(selected_tractogram), (len(self.streamlines)+1)//2) - assert_arrays_equal(selected_streamlines.points, self.points[::2]) - assert_equal(sum(map(len, selected_streamlines.scalars)), 0) - assert_equal(sum(map(len, selected_streamlines.properties)), 0) + assert_arrays_equal(selected_tractogram.streamlines, self.streamlines[::2]) + assert_equal(sum(map(len, selected_tractogram.scalars)), 0) + assert_equal(sum(map(len, selected_tractogram.properties)), 0) - # Tractogram with points, scalars and properties - streamlines = Tractogram(self.points, self.colors, self.mean_curvature_torsion) + # Tractogram with streamlines, scalars and properties + tractogram = Tractogram(self.streamlines, self.colors, self.mean_curvature_torsion) - # Retrieve streamlines by their index - for i, streamline in enumerate(streamlines): - assert_array_equal(streamline.points, streamlines[i].points) - assert_array_equal(streamline.scalars, streamlines[i].scalars) - assert_array_equal(streamline.properties, streamlines[i].properties) + # Retrieve tractogram by their index + for i, t in enumerate(tractogram): + assert_array_equal(t.streamline, tractogram[i].streamline) + assert_array_equal(t.scalars, tractogram[i].scalars) + assert_array_equal(t.properties, tractogram[i].properties) # Use slicing - r_streamlines = streamlines[::-1] - assert_arrays_equal(r_streamlines.points, self.points[::-1]) - assert_arrays_equal(r_streamlines.scalars, self.colors[::-1]) - assert_arrays_equal(r_streamlines.properties, self.mean_curvature_torsion[::-1]) + r_tractogram = tractogram[::-1] + assert_arrays_equal(r_tractogram.streamlines, self.streamlines[::-1]) + assert_arrays_equal(r_tractogram.scalars, self.colors[::-1]) + assert_arrays_equal(r_tractogram.properties, self.mean_curvature_torsion[::-1]) - def test_streamlines_creation_from_generator(self): + def test_tractogram_creation_from_generator(self): # Create `Tractogram` from a generator yielding 3-tuples - gen = (x for x in zip(self.points, self.colors, self.mean_curvature_torsion)) + gen = (x for x in zip(self.streamlines, self.colors, self.mean_curvature_torsion)) - streamlines = Tractogram.create_from_generator(gen) + tractogram = Tractogram.create_from_generator(gen) with suppress_warnings(): - assert_equal(len(streamlines), self.nb_streamlines) + assert_equal(len(tractogram), self.nb_tractogram) - assert_arrays_equal(streamlines.points, self.points) - assert_arrays_equal(streamlines.scalars, self.colors) - 
assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.scalars, self.colors) + assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) - # Check if we can iterate through the streamlines. - for streamline in streamlines: + # Check if we can iterate through the tractogram. + for streamline in tractogram: pass - def test_streamlines_creation_from_coroutines(self): + def test_tractogram_creation_from_coroutines(self): # Points, scalars and properties - points = lambda: (x for x in self.points) + streamlines = lambda: (x for x in self.streamlines) scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) - # To create streamlines from multiple coroutines use `LazyTractogram`. - assert_raises(TypeError, Tractogram, points, scalars, properties) + # To create tractogram from multiple coroutines use `LazyTractogram`. + assert_raises(TypeError, Tractogram, streamlines, scalars, properties) def test_header(self): # Empty Tractogram, with default header - streamlines = Tractogram() - assert_equal(streamlines.header.nb_streamlines, 0) - assert_equal(streamlines.header.nb_scalars_per_point, 0) - assert_equal(streamlines.header.nb_properties_per_streamline, 0) - assert_array_equal(streamlines.header.voxel_sizes, (1, 1, 1)) - assert_array_equal(streamlines.header.to_world_space, np.eye(4)) - assert_equal(streamlines.header.extra, {}) + tractogram = Tractogram() + assert_equal(tractogram.header.nb_streamlines, 0) + assert_equal(tractogram.header.nb_scalars_per_point, 0) + assert_equal(tractogram.header.nb_properties_per_streamline, 0) + assert_array_equal(tractogram.header.voxel_sizes, (1, 1, 1)) + assert_array_equal(tractogram.header.to_world_space, np.eye(4)) + assert_equal(tractogram.header.extra, {}) - streamlines = Tractogram(self.points, self.colors, self.mean_curvature_torsion) + tractogram = Tractogram(self.streamlines, self.colors, self.mean_curvature_torsion) - assert_equal(streamlines.header.nb_streamlines, len(self.points)) - assert_equal(streamlines.header.nb_scalars_per_point, self.colors[0].shape[1]) - assert_equal(streamlines.header.nb_properties_per_streamline, self.mean_curvature_torsion[0].shape[0]) + assert_equal(tractogram.header.nb_streamlines, len(self.streamlines)) + assert_equal(tractogram.header.nb_scalars_per_point, self.colors[0].shape[1]) + assert_equal(tractogram.header.nb_properties_per_streamline, self.mean_curvature_torsion[0].shape[0]) # Modifying voxel_sizes should be reflected in to_world_space - streamlines.header.voxel_sizes = (2, 3, 4) - assert_array_equal(streamlines.header.voxel_sizes, (2, 3, 4)) - assert_array_equal(np.diag(streamlines.header.to_world_space), (2, 3, 4, 1)) + tractogram.header.voxel_sizes = (2, 3, 4) + assert_array_equal(tractogram.header.voxel_sizes, (2, 3, 4)) + assert_array_equal(np.diag(tractogram.header.to_world_space), (2, 3, 4, 1)) # Modifying scaling of to_world_space should be reflected in voxel_sizes - streamlines.header.to_world_space = np.diag([4, 3, 2, 1]) - assert_array_equal(streamlines.header.voxel_sizes, (4, 3, 2)) - assert_array_equal(streamlines.header.to_world_space, np.diag([4, 3, 2, 1])) + tractogram.header.to_world_space = np.diag([4, 3, 2, 1]) + assert_array_equal(tractogram.header.voxel_sizes, (4, 3, 2)) + assert_array_equal(tractogram.header.to_world_space, np.diag([4, 3, 2, 1])) # Test that we can run __repr__ without error. 
-        repr(streamlines.header)
-
+        repr(tractogram.header)
 
 
 class TestLazyTractogram(unittest.TestCase):
 
     def setUp(self):
@@ -393,9 +392,9 @@ def setUp(self):
         self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk")
         self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk")
 
-        self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)),
-                       np.arange(2*3, dtype="f4").reshape((2, 3)),
-                       np.arange(5*3, dtype="f4").reshape((5, 3))]
+        self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)),
+                            np.arange(2*3, dtype="f4").reshape((2, 3)),
+                            np.arange(5*3, dtype="f4").reshape((5, 3))]
 
         self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"),
                        np.array([(0, 1, 0)]*2, dtype="f4"),
@@ -405,99 +404,99 @@ def setUp(self):
                        np.array([2.11, 2.22], dtype="f4"),
                        np.array([3.11, 3.22], dtype="f4")]
 
-        self.nb_streamlines = len(self.points)
+        self.nb_streamlines = len(self.streamlines)
         self.nb_scalars_per_point = self.colors[0].shape[1]
         self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0])
 
-    def test_lazy_streamlines_creation(self):
-        # To create streamlines from arrays use `Tractogram`.
-        assert_raises(TypeError, LazyTractogram, self.points)
+    def test_lazy_tractogram_creation(self):
+        # To create a tractogram from arrays use `Tractogram`.
+        assert_raises(TypeError, LazyTractogram, self.streamlines)
 
         # Points, scalars and properties
-        points = (x for x in self.points)
+        streamlines = (x for x in self.streamlines)
         scalars = (x for x in self.colors)
         properties = (x for x in self.mean_curvature_torsion)
 
         # Creating LazyTractogram from generators is not allowed as
         # generators get exhausted and are not reusable unlike coroutines.
-        assert_raises(TypeError, LazyTractogram, points)
-        assert_raises(TypeError, LazyTractogram, self.points, scalars)
+        assert_raises(TypeError, LazyTractogram, streamlines)
+        assert_raises(TypeError, LazyTractogram, self.streamlines, scalars)
         assert_raises(TypeError, LazyTractogram, properties_func=properties)
 
         # Empty `LazyTractogram`
-        streamlines = LazyTractogram()
+        tractogram = LazyTractogram()
         with suppress_warnings():
-            assert_equal(len(streamlines), 0)
-        assert_arrays_equal(streamlines.points, [])
-        assert_arrays_equal(streamlines.scalars, [])
-        assert_arrays_equal(streamlines.properties, [])
+            assert_equal(len(tractogram), 0)
+        assert_arrays_equal(tractogram.streamlines, [])
+        assert_arrays_equal(tractogram.scalars, [])
+        assert_arrays_equal(tractogram.properties, [])
 
-        # Check if we can iterate through the streamlines.
-        for streamline in streamlines:
+        # Check if we can iterate through the tractogram.
+        for streamline in tractogram:
             pass
 
         # Points, scalars and properties
-        points = lambda: (x for x in self.points)
+        streamlines = lambda: (x for x in self.streamlines)
         scalars = lambda: (x for x in self.colors)
         properties = lambda: (x for x in self.mean_curvature_torsion)
 
-        streamlines = LazyTractogram(points, scalars, properties)
+        tractogram = LazyTractogram(streamlines, scalars, properties)
         with suppress_warnings():
-            assert_equal(len(streamlines), self.nb_streamlines)
+            assert_equal(len(tractogram), self.nb_streamlines)
 
         # Coroutines get re-called and create new iterators.
- assert_arrays_equal(streamlines.points, self.points) - assert_arrays_equal(streamlines.scalars, self.colors) - assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) - assert_arrays_equal(streamlines.points, self.points) - assert_arrays_equal(streamlines.scalars, self.colors) - assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.scalars, self.colors) + assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.scalars, self.colors) + assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) # Create `LazyTractogram` from a coroutine yielding 3-tuples - data = lambda: (x for x in zip(self.points, self.colors, self.mean_curvature_torsion)) + data = lambda: (x for x in zip(self.streamlines, self.colors, self.mean_curvature_torsion)) - streamlines = LazyTractogram.create_from_data(data) + tractogram = LazyTractogram.create_from_data(data) with suppress_warnings(): - assert_equal(len(streamlines), self.nb_streamlines) - assert_arrays_equal(streamlines.points, self.points) - assert_arrays_equal(streamlines.scalars, self.colors) - assert_arrays_equal(streamlines.properties, self.mean_curvature_torsion) + assert_equal(len(tractogram), self.nb_streamlines) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.scalars, self.colors) + assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) - # Check if we can iterate through the streamlines. - for streamline in streamlines: + # Check if we can iterate through the tractogram. + for streamline in tractogram: pass - def test_lazy_streamlines_indexing(self): - points = lambda: (x for x in self.points) + def test_lazy_tractogram_indexing(self): + streamlines = lambda: (x for x in self.streamlines) scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) # By default, `LazyTractogram` object does not support indexing. - streamlines = LazyTractogram(points, scalars, properties) - assert_raises(AttributeError, streamlines.__getitem__, 0) + tractogram = LazyTractogram(streamlines, scalars, properties) + assert_raises(AttributeError, tractogram.__getitem__, 0) # Create a `LazyTractogram` object with indexing support. 
def getitem_without_properties(idx): if isinstance(idx, int) or isinstance(idx, np.integer): - return self.points[idx], self.colors[idx] + return self.streamlines[idx], self.colors[idx] - return list(zip(self.points[idx], self.colors[idx])) + return list(zip(self.streamlines[idx], self.colors[idx])) - streamlines = LazyTractogram(points, scalars, properties, getitem_without_properties) - points, scalars = streamlines[0] - assert_array_equal(points, self.points[0]) + tractogram = LazyTractogram(streamlines, scalars, properties, getitem_without_properties) + streamlines, scalars = tractogram[0] + assert_array_equal(streamlines, self.streamlines[0]) assert_array_equal(scalars, self.colors[0]) - points, scalars = zip(*streamlines[::-1]) - assert_arrays_equal(points, self.points[::-1]) + streamlines, scalars = zip(*tractogram[::-1]) + assert_arrays_equal(streamlines, self.streamlines[::-1]) assert_arrays_equal(scalars, self.colors[::-1]) - points, scalars = zip(*streamlines[:-1]) - assert_arrays_equal(points, self.points[:-1]) + streamlines, scalars = zip(*tractogram[:-1]) + assert_arrays_equal(streamlines, self.streamlines[:-1]) assert_arrays_equal(scalars, self.colors[:-1]) - def test_lazy_streamlines_len(self): - points = lambda: (x for x in self.points) + def test_lazy_tractogram_len(self): + streamlines = lambda: (x for x in self.streamlines) scalars = lambda: (x for x in self.colors) properties = lambda: (x for x in self.mean_curvature_torsion) @@ -505,61 +504,61 @@ def test_lazy_streamlines_len(self): warnings.simplefilter("always") # Always trigger warnings. # Calling `len` will create new generators each time. - streamlines = LazyTractogram(points, scalars, properties) + tractogram = LazyTractogram(streamlines, scalars, properties) # This should produce a warning message. - assert_equal(len(streamlines), self.nb_streamlines) + assert_equal(len(tractogram), self.nb_streamlines) assert_equal(len(w), 1) - streamlines = LazyTractogram(points, scalars, properties) + tractogram = LazyTractogram(streamlines, scalars, properties) # This should still produce a warning message. - assert_equal(len(streamlines), self.nb_streamlines) + assert_equal(len(tractogram), self.nb_streamlines) assert_equal(len(w), 2) assert_true(issubclass(w[-1].category, UsageWarning)) # This should *not* produce a warning. - assert_equal(len(streamlines), self.nb_streamlines) + assert_equal(len(tractogram), self.nb_streamlines) assert_equal(len(w), 2) with clear_and_catch_warnings(record=True, modules=[base_format]) as w: - # Once we iterated through the streamlines, we know the length. - streamlines = LazyTractogram(points, scalars, properties) - assert_true(streamlines.header.nb_streamlines is None) - for streamline in streamlines: + # Once we iterated through the tractogram, we know the length. + tractogram = LazyTractogram(streamlines, scalars, properties) + assert_true(tractogram.header.nb_streamlines is None) + for streamline in tractogram: pass - assert_equal(streamlines.header.nb_streamlines, len(self.points)) + assert_equal(tractogram.header.nb_streamlines, len(self.streamlines)) # This should *not* produce a warning. - assert_equal(len(streamlines), len(self.points)) + assert_equal(len(tractogram), len(self.streamlines)) assert_equal(len(w), 0) with clear_and_catch_warnings(record=True, modules=[base_format]) as w: - # It first checks if number of streamlines is in the header. 
- streamlines = LazyTractogram(points, scalars, properties)
- streamlines.header.nb_streamlines = 1234
+ # It first checks if the number of streamlines is in the header.
+ tractogram = LazyTractogram(streamlines, scalars, properties)
+ tractogram.header.nb_streamlines = 1234

# This should *not* produce a warning.
- assert_equal(len(streamlines), 1234)
+ assert_equal(len(tractogram), 1234)
assert_equal(len(w), 0)

- def test_lazy_streamlines_header(self):
+ def test_lazy_tractogram_header(self):
# Empty `LazyTractogram`, with default header
- streamlines = LazyTractogram()
- assert_true(streamlines.header.nb_streamlines is None)
- assert_equal(streamlines.header.nb_scalars_per_point, 0)
- assert_equal(streamlines.header.nb_properties_per_streamline, 0)
- assert_array_equal(streamlines.header.voxel_sizes, (1, 1, 1))
- assert_array_equal(streamlines.header.to_world_space, np.eye(4))
- assert_equal(streamlines.header.extra, {})
-
- points = lambda: (x for x in self.points)
+ tractogram = LazyTractogram()
+ assert_true(tractogram.header.nb_streamlines is None)
+ assert_equal(tractogram.header.nb_scalars_per_point, 0)
+ assert_equal(tractogram.header.nb_properties_per_streamline, 0)
+ assert_array_equal(tractogram.header.voxel_sizes, (1, 1, 1))
+ assert_array_equal(tractogram.header.to_world_space, np.eye(4))
+ assert_equal(tractogram.header.extra, {})
+
+ streamlines = lambda: (x for x in self.streamlines)
scalars = lambda: (x for x in self.colors)
properties = lambda: (x for x in self.mean_curvature_torsion)

- streamlines = LazyTractogram(points)
- header = streamlines.header
+ tractogram = LazyTractogram(streamlines)
+ header = tractogram.header

assert_equal(header.nb_scalars_per_point, 0)
- streamlines.scalars = scalars
+ tractogram.scalars = scalars
assert_equal(header.nb_scalars_per_point, self.nb_scalars_per_point)

assert_equal(header.nb_properties_per_streamline, 0)
- streamlines.properties = properties
+ tractogram.properties = properties
assert_equal(header.nb_properties_per_streamline, self.nb_properties_per_streamline)

diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py
index 1e94b79911..5cd30c65f6 100644
--- a/nibabel/streamlines/tests/test_streamlines.py
+++ b/nibabel/streamlines/tests/test_streamlines.py
@@ -9,7 +9,7 @@
from nibabel.externals.six import BytesIO
from nibabel.testing import clear_and_catch_warnings
-from nibabel.testing import assert_arrays_equal
+from nibabel.testing import assert_arrays_equal, isiterable
from nose.tools import assert_equal, assert_raises, assert_true, assert_false

from ..base_format import Tractogram, LazyTractogram
@@ -20,29 +20,19 @@
DATA_PATH = pjoin(os.path.dirname(__file__), 'data')

-def isiterable(streamlines):
- try:
- for _ in streamlines:
- pass
- except:
- return False
-
- return True
-
-
-def check_streamlines(streamlines, nb_streamlines, points, scalars, properties):
+def check_tractogram(tractogram, nb_streamlines, streamlines, scalars, properties):
# Check data
- assert_equal(len(streamlines), nb_streamlines)
- assert_arrays_equal(streamlines.points, points)
- assert_arrays_equal(streamlines.scalars, scalars)
- assert_arrays_equal(streamlines.properties, properties)
- assert_true(isiterable(streamlines))
+ assert_equal(len(tractogram), nb_streamlines)
+ assert_arrays_equal(tractogram.streamlines, streamlines)
+ assert_arrays_equal(tractogram.scalars, scalars)
+ assert_arrays_equal(tractogram.properties, properties)
+ assert_true(isiterable(tractogram))

-
assert_equal(streamlines.header.nb_streamlines, nb_streamlines) + assert_equal(tractogram.header.nb_streamlines, nb_streamlines) nb_scalars_per_point = 0 if len(scalars) == 0 else len(scalars[0][0]) nb_properties_per_streamline = 0 if len(properties) == 0 else len(properties[0]) - assert_equal(streamlines.header.nb_scalars_per_point, nb_scalars_per_point) - assert_equal(streamlines.header.nb_properties_per_streamline, nb_properties_per_streamline) + assert_equal(tractogram.header.nb_scalars_per_point, nb_scalars_per_point) + assert_equal(tractogram.header.nb_properties_per_streamline, nb_properties_per_streamline) def test_is_supported(): @@ -52,33 +42,33 @@ def test_is_supported(): assert_false(nib.streamlines.is_supported("")) # Valid file without extension - for streamlines_file in nib.streamlines.FORMATS.values(): + for tractogram_file in nib.streamlines.FORMATS.values(): f = BytesIO() - f.write(streamlines_file.get_magic_number()) + f.write(tractogram_file.get_magic_number()) f.seek(0, os.SEEK_SET) assert_true(nib.streamlines.is_supported(f)) # Wrong extension but right magic number - for streamlines_file in nib.streamlines.FORMATS.values(): + for tractogram_file in nib.streamlines.FORMATS.values(): with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f: - f.write(streamlines_file.get_magic_number()) + f.write(tractogram_file.get_magic_number()) f.seek(0, os.SEEK_SET) assert_true(nib.streamlines.is_supported(f)) # Good extension but wrong magic number - for ext, streamlines_file in nib.streamlines.FORMATS.items(): + for ext, tractogram_file in nib.streamlines.FORMATS.items(): with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f: f.write(b"pass") f.seek(0, os.SEEK_SET) assert_false(nib.streamlines.is_supported(f)) # Wrong extension, string only - f = "my_streamlines.asd" + f = "my_tractogram.asd" assert_false(nib.streamlines.is_supported(f)) # Good extension, string only - for ext, streamlines_file in nib.streamlines.FORMATS.items(): - f = "my_streamlines" + ext + for ext, tractogram_file in nib.streamlines.FORMATS.items(): + f = "my_tractogram" + ext assert_true(nib.streamlines.is_supported(f)) @@ -89,34 +79,34 @@ def test_detect_format(): assert_equal(nib.streamlines.detect_format(""), None) # Valid file without extension - for streamlines_file in nib.streamlines.FORMATS.values(): + for tractogram_file in nib.streamlines.FORMATS.values(): f = BytesIO() - f.write(streamlines_file.get_magic_number()) + f.write(tractogram_file.get_magic_number()) f.seek(0, os.SEEK_SET) - assert_equal(nib.streamlines.detect_format(f), streamlines_file) + assert_equal(nib.streamlines.detect_format(f), tractogram_file) # Wrong extension but right magic number - for streamlines_file in nib.streamlines.FORMATS.values(): + for tractogram_file in nib.streamlines.FORMATS.values(): with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f: - f.write(streamlines_file.get_magic_number()) + f.write(tractogram_file.get_magic_number()) f.seek(0, os.SEEK_SET) - assert_equal(nib.streamlines.detect_format(f), streamlines_file) + assert_equal(nib.streamlines.detect_format(f), tractogram_file) # Good extension but wrong magic number - for ext, streamlines_file in nib.streamlines.FORMATS.items(): + for ext, tractogram_file in nib.streamlines.FORMATS.items(): with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f: f.write(b"pass") f.seek(0, os.SEEK_SET) assert_equal(nib.streamlines.detect_format(f), None) # Wrong extension, string only - f = "my_streamlines.asd" + f = "my_tractogram.asd" 
assert_equal(nib.streamlines.detect_format(f), None) # Good extension, string only - for ext, streamlines_file in nib.streamlines.FORMATS.items(): - f = "my_streamlines" + ext - assert_equal(nib.streamlines.detect_format(f), streamlines_file) + for ext, tractogram_file in nib.streamlines.FORMATS.items(): + f = "my_tractogram" + ext + assert_equal(nib.streamlines.detect_format(f), tractogram_file) class TestLoadSave(unittest.TestCase): @@ -125,9 +115,9 @@ def setUp(self): self.simple_filenames = [pjoin(DATA_PATH, "simple" + ext) for ext in nib.streamlines.FORMATS.keys()] self.complex_filenames = [pjoin(DATA_PATH, "complex" + ext) for ext in nib.streamlines.FORMATS.keys()] - self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), - np.arange(2*3, dtype="f4").reshape((2, 3)), - np.arange(5*3, dtype="f4").reshape((5, 3))] + self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), np.array([(0, 1, 0)]*2, dtype="f4"), @@ -137,35 +127,35 @@ def setUp(self): np.array([2.11, 2.22], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] - self.nb_streamlines = len(self.points) + self.nb_streamlines = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) self.to_world_space = np.eye(4) def test_load_empty_file(self): for empty_filename in self.empty_filenames: - streamlines = nib.streamlines.load(empty_filename, + tractogram = nib.streamlines.load(empty_filename, ref=self.to_world_space, lazy_load=False) - assert_true(type(streamlines), Tractogram) - check_streamlines(streamlines, 0, [], [], []) + assert_true(type(tractogram), Tractogram) + check_tractogram(tractogram, 0, [], [], []) def test_load_simple_file(self): for simple_filename in self.simple_filenames: - streamlines = nib.streamlines.load(simple_filename, + tractogram = nib.streamlines.load(simple_filename, ref=self.to_world_space, lazy_load=False) - assert_true(type(streamlines), Tractogram) - check_streamlines(streamlines, self.nb_streamlines, - self.points, [], []) + assert_true(type(tractogram), Tractogram) + check_tractogram(tractogram, self.nb_streamlines, + self.streamlines, [], []) # Test lazy_load - streamlines = nib.streamlines.load(simple_filename, + tractogram = nib.streamlines.load(simple_filename, ref=self.to_world_space, lazy_load=True) - assert_true(type(streamlines), LazyTractogram) - check_streamlines(streamlines, self.nb_streamlines, - self.points, [], []) + assert_true(type(tractogram), LazyTractogram) + check_tractogram(tractogram, self.nb_streamlines, + self.streamlines, [], []) def test_load_complex_file(self): for complex_filename in self.complex_filenames: @@ -179,36 +169,36 @@ def test_load_complex_file(self): if file_format.can_save_properties(): properties = self.mean_curvature_torsion - streamlines = nib.streamlines.load(complex_filename, + tractogram = nib.streamlines.load(complex_filename, ref=self.to_world_space, lazy_load=False) - assert_true(type(streamlines), Tractogram) - check_streamlines(streamlines, self.nb_streamlines, - self.points, scalars, properties) + assert_true(type(tractogram), Tractogram) + check_tractogram(tractogram, self.nb_streamlines, + self.streamlines, scalars, properties) # Test lazy_load - streamlines = nib.streamlines.load(complex_filename, + tractogram = nib.streamlines.load(complex_filename, ref=self.to_world_space, lazy_load=True) - 
assert_true(type(streamlines), LazyTractogram) - check_streamlines(streamlines, self.nb_streamlines, - self.points, scalars, properties) + assert_true(type(tractogram), LazyTractogram) + check_tractogram(tractogram, self.nb_streamlines, + self.streamlines, scalars, properties) def test_save_simple_file(self): - streamlines = Tractogram(self.points) + tractogram = Tractogram(self.streamlines) for ext in nib.streamlines.FORMATS.keys(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: - nib.streamlines.save(streamlines, f.name) - loaded_streamlines = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) - check_streamlines(loaded_streamlines, self.nb_streamlines, - self.points, [], []) + nib.streamlines.save(tractogram, f.name) + loaded_tractogram = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) + check_tractogram(loaded_tractogram, self.nb_streamlines, + self.streamlines, [], []) def test_save_complex_file(self): - streamlines = Tractogram(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) + tractogram = Tractogram(self.streamlines, scalars=self.colors, properties=self.mean_curvature_torsion) for ext, cls in nib.streamlines.FORMATS.items(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: with clear_and_catch_warnings(record=True, modules=[trk]) as w: - nib.streamlines.save(streamlines, f.name) + nib.streamlines.save(tractogram, f.name) # If streamlines format does not support saving scalars or # properties, a warning message should be issued. @@ -224,6 +214,6 @@ def test_save_complex_file(self): if cls.can_save_properties(): properties = self.mean_curvature_torsion - loaded_streamlines = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) - check_streamlines(loaded_streamlines, self.nb_streamlines, - self.points, scalars, properties) + loaded_tractogram = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) + check_tractogram(loaded_tractogram, self.nb_streamlines, + self.streamlines, scalars, properties) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 0c5d58aadf..40720027fc 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -5,7 +5,7 @@ from nibabel.externals.six import BytesIO from nibabel.testing import suppress_warnings, clear_and_catch_warnings -from nibabel.testing import assert_arrays_equal, assert_streamlines_equal +from nibabel.testing import assert_arrays_equal, assert_tractogram_equal, isiterable from nose.tools import assert_equal, assert_raises, assert_true from .. 
import base_format @@ -18,43 +18,33 @@ DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') -def isiterable(streamlines): - try: - for streamline in streamlines: - pass - except: - return False - - return True - - -def check_streamlines(streamlines, nb_streamlines, points, scalars, properties): +def check_tractogram(tractogram, nb_streamlines, streamlines, scalars, properties): # Check data - assert_equal(len(streamlines), nb_streamlines) - assert_arrays_equal(streamlines.points, points) - assert_arrays_equal(streamlines.scalars, scalars) - assert_arrays_equal(streamlines.properties, properties) - assert_true(isiterable(streamlines)) + assert_equal(len(tractogram), nb_streamlines) + assert_arrays_equal(tractogram.streamlines, streamlines) + assert_arrays_equal(tractogram.scalars, scalars) + assert_arrays_equal(tractogram.properties, properties) + assert_true(isiterable(tractogram)) - assert_equal(streamlines.header.nb_streamlines, nb_streamlines) + assert_equal(tractogram.header.nb_streamlines, nb_streamlines) nb_scalars_per_point = 0 if len(scalars) == 0 else len(scalars[0][0]) nb_properties_per_streamline = 0 if len(properties) == 0 else len(properties[0]) - assert_equal(streamlines.header.nb_scalars_per_point, nb_scalars_per_point) - assert_equal(streamlines.header.nb_properties_per_streamline, nb_properties_per_streamline) + assert_equal(tractogram.header.nb_scalars_per_point, nb_scalars_per_point) + assert_equal(tractogram.header.nb_properties_per_streamline, nb_properties_per_streamline) class TestTRK(unittest.TestCase): def setUp(self): self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") - # simple.trk contains only points + # simple.trk contains only streamlines self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") - # complex.trk contains points, scalars and properties + # complex.trk contains streamlines, scalars and properties self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") - self.points = [np.arange(1*3, dtype="f4").reshape((1, 3)), - np.arange(2*3, dtype="f4").reshape((2, 3)), - np.arange(5*3, dtype="f4").reshape((5, 3))] + self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), np.array([(0, 1, 0)]*2, dtype="f4"), @@ -64,34 +54,34 @@ def setUp(self): np.array([2.11, 2.22], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] - self.nb_streamlines = len(self.points) + self.nb_streamlines = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) def test_load_empty_file(self): trk = TrkFile.load(self.empty_trk_filename, ref=None, lazy_load=False) - check_streamlines(trk, 0, [], [], []) + check_tractogram(trk, 0, [], [], []) trk = TrkFile.load(self.empty_trk_filename, ref=None, lazy_load=True) # Suppress warning about loading a TRK file in lazy mode with count=0. 
with suppress_warnings(): - check_streamlines(trk, 0, [], [], []) + check_tractogram(trk, 0, [], [], []) def test_load_simple_file(self): trk = TrkFile.load(self.simple_trk_filename, ref=None, lazy_load=False) - check_streamlines(trk, self.nb_streamlines, self.points, [], []) + check_tractogram(trk, self.nb_streamlines, self.streamlines, [], []) trk = TrkFile.load(self.simple_trk_filename, ref=None, lazy_load=True) - check_streamlines(trk, self.nb_streamlines, self.points, [], []) + check_tractogram(trk, self.nb_streamlines, self.streamlines, [], []) def test_load_complex_file(self): trk = TrkFile.load(self.complex_trk_filename, ref=None, lazy_load=False) - check_streamlines(trk, self.nb_streamlines, - self.points, self.colors, self.mean_curvature_torsion) + check_tractogram(trk, self.nb_streamlines, + self.streamlines, self.colors, self.mean_curvature_torsion) trk = TrkFile.load(self.complex_trk_filename, ref=None, lazy_load=True) - check_streamlines(trk, self.nb_streamlines, - self.points, self.colors, self.mean_curvature_torsion) + check_tractogram(trk, self.nb_streamlines, + self.streamlines, self.colors, self.mean_curvature_torsion) def test_load_file_with_wrong_information(self): trk_file = open(self.simple_trk_filename, 'rb').read() @@ -99,12 +89,12 @@ def test_load_file_with_wrong_information(self): # Simulate a TRK file where `count` was not provided. count = np.array(0, dtype="int32").tostring() new_trk_file = trk_file[:1000-12] + count + trk_file[1000-8:] - streamlines = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) - check_streamlines(streamlines, self.nb_streamlines, self.points, [], []) + tractogram = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) + check_tractogram(tractogram, self.nb_streamlines, self.streamlines, [], []) - streamlines = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) + tractogram = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) with clear_and_catch_warnings(record=True, modules=[base_format]) as w: - check_streamlines(streamlines, self.nb_streamlines, self.points, [], []) + check_tractogram(tractogram, self.nb_streamlines, self.streamlines, [], []) assert_equal(len(w), 1) assert_true(issubclass(w[0].category, UsageWarning)) @@ -128,59 +118,59 @@ def test_load_file_with_wrong_information(self): assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) def test_write_simple_file(self): - streamlines = Tractogram(self.points) + tractogram = Tractogram(self.streamlines) trk_file = BytesIO() - TrkFile.save(streamlines, trk_file) + TrkFile.save(tractogram, trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_streamlines = TrkFile.load(trk_file) - check_streamlines(loaded_streamlines, self.nb_streamlines, - self.points, [], []) + loaded_tractogram = TrkFile.load(trk_file) + check_tractogram(loaded_tractogram, self.nb_streamlines, + self.streamlines, [], []) - loaded_streamlines_orig = TrkFile.load(self.simple_trk_filename) - assert_streamlines_equal(loaded_streamlines, loaded_streamlines_orig) + loaded_tractogram_orig = TrkFile.load(self.simple_trk_filename) + assert_tractogram_equal(loaded_tractogram, loaded_tractogram_orig) trk_file.seek(0, os.SEEK_SET) assert_equal(open(self.simple_trk_filename, 'rb').read(), trk_file.read()) def test_write_complex_file(self): # With scalars - streamlines = Tractogram(self.points, scalars=self.colors) + tractogram = Tractogram(self.streamlines, scalars=self.colors) trk_file = BytesIO() - TrkFile.save(streamlines, trk_file) + TrkFile.save(tractogram, trk_file) trk_file.seek(0, os.SEEK_SET) - 
loaded_streamlines = TrkFile.load(trk_file, ref=None, lazy_load=False) + loaded_tractogram = TrkFile.load(trk_file, ref=None, lazy_load=False) - check_streamlines(loaded_streamlines, self.nb_streamlines, - self.points, self.colors, []) + check_tractogram(loaded_tractogram, self.nb_streamlines, + self.streamlines, self.colors, []) # With properties - streamlines = Tractogram(self.points, properties=self.mean_curvature_torsion) + tractogram = Tractogram(self.streamlines, properties=self.mean_curvature_torsion) trk_file = BytesIO() - TrkFile.save(streamlines, trk_file) + TrkFile.save(tractogram, trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_streamlines = TrkFile.load(trk_file, ref=None, lazy_load=False) - check_streamlines(loaded_streamlines, self.nb_streamlines, - self.points, [], self.mean_curvature_torsion) + loaded_tractogram = TrkFile.load(trk_file, ref=None, lazy_load=False) + check_tractogram(loaded_tractogram, self.nb_streamlines, + self.streamlines, [], self.mean_curvature_torsion) # With scalars and properties - streamlines = Tractogram(self.points, scalars=self.colors, properties=self.mean_curvature_torsion) + tractogram = Tractogram(self.streamlines, scalars=self.colors, properties=self.mean_curvature_torsion) trk_file = BytesIO() - TrkFile.save(streamlines, trk_file) + TrkFile.save(tractogram, trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_streamlines = TrkFile.load(trk_file, ref=None, lazy_load=False) - check_streamlines(loaded_streamlines, self.nb_streamlines, - self.points, self.colors, self.mean_curvature_torsion) + loaded_tractogram = TrkFile.load(trk_file, ref=None, lazy_load=False) + check_tractogram(loaded_tractogram, self.nb_streamlines, + self.streamlines, self.colors, self.mean_curvature_torsion) - loaded_streamlines_orig = TrkFile.load(self.complex_trk_filename) - assert_streamlines_equal(loaded_streamlines, loaded_streamlines_orig) + loaded_tractogram_orig = TrkFile.load(self.complex_trk_filename) + assert_tractogram_equal(loaded_tractogram, loaded_tractogram_orig) trk_file.seek(0, os.SEEK_SET) assert_equal(open(self.complex_trk_filename, 'rb').read(), trk_file.read()) @@ -191,15 +181,15 @@ def test_write_erroneous_file(self): [(0, 1, 0)], [(0, 0, 1)]] - streamlines = Tractogram(self.points, scalars) - assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + tractogram = Tractogram(self.streamlines, scalars) + assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) # No scalars for every streamlines scalars = [[(1, 0, 0)]*1, [(0, 1, 0)]*2] - streamlines = Tractogram(self.points, scalars) - assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + tractogram = Tractogram(self.streamlines, scalars) + assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) # # Unit test moved to test_base_format.py # # Inconsistent number of scalars between points @@ -207,8 +197,8 @@ def test_write_erroneous_file(self): # [(0, 1, 0), (0, 1)], # [(0, 0, 1)]*5] - # streamlines = Tractogram(self.points, scalars) - # assert_raises(ValueError, TrkFile.save, streamlines, BytesIO()) + # tractogram = Tractogram(self.streamlines, scalars) + # assert_raises(ValueError, TrkFile.save, tractogram, BytesIO()) # # Unit test moved to test_base_format.py # # Inconsistent number of scalars between streamlines @@ -216,38 +206,38 @@ def test_write_erroneous_file(self): # [(0, 1)]*2, # [(0, 0, 1)]*5] - # streamlines = Tractogram(self.points, scalars) - # assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + # tractogram = Tractogram(self.streamlines, scalars) + # 
assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) # Unit test moved to test_base_format.py # Inconsistent number of properties properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] - streamlines = Tractogram(self.points, properties=properties) - assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + tractogram = Tractogram(self.streamlines, properties=properties) + assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) # Unit test moved to test_base_format.py # No properties for every streamlines properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11, 2.22], dtype="f4")] - streamlines = Tractogram(self.points, properties=properties) - assert_raises(DataError, TrkFile.save, streamlines, BytesIO()) + tractogram = Tractogram(self.streamlines, properties=properties) + assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) - def test_write_file_lazy_streamlines(self): - points = lambda: (point for point in self.points) + def test_write_file_lazy_tractogram(self): + streamlines = lambda: (point for point in self.streamlines) scalars = lambda: (scalar for scalar in self.colors) properties = lambda: (prop for prop in self.mean_curvature_torsion) - streamlines = LazyTractogram(points, scalars, properties) + tractogram = LazyTractogram(streamlines, scalars, properties) # No need to manually set `nb_streamlines` in the header since we count # them as writing. - #streamlines.header.nb_streamlines = self.nb_streamlines + #tractogram.header.nb_streamlines = self.nb_streamlines trk_file = BytesIO() - TrkFile.save(streamlines, trk_file) + TrkFile.save(tractogram, trk_file) trk_file.seek(0, os.SEEK_SET) trk = TrkFile.load(trk_file, ref=None, lazy_load=False) - check_streamlines(trk, self.nb_streamlines, - self.points, self.colors, self.mean_curvature_torsion) + check_tractogram(trk, self.nb_streamlines, + self.streamlines, self.colors, self.mean_curvature_torsion) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 3194218190..ae913ad67f 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -250,17 +250,17 @@ def __init__(self, fileobj, header): self.beginning = self.file.tell() self.file.write(self.header[0].tostring()) - def write(self, streamlines): + def write(self, tractogram): i4_dtype = np.dtype("i4") f4_dtype = np.dtype("f4") - for s in streamlines: - if len(s.scalars) > 0 and len(s.scalars) != len(s.points): + for t in tractogram: + if len(t.scalars) > 0 and len(t.scalars) != len(t.streamline): raise DataError("Missing scalars for some points!") - points = np.asarray(s.points, dtype=f4_dtype) - scalars = np.asarray(s.scalars, dtype=f4_dtype).reshape((len(points), -1)) - properties = np.asarray(s.properties, dtype=f4_dtype) + points = np.asarray(t.streamline, dtype=f4_dtype) + scalars = np.asarray(t.scalars, dtype=f4_dtype).reshape((len(points), -1)) + properties = np.asarray(t.properties, dtype=f4_dtype) # TRK's streamlines need to be in 'voxelmm' space points = points * self.header[Field.VOXEL_SIZES] @@ -377,8 +377,8 @@ def load(fileobj, ref=None, lazy_load=False): Returns ------- - streamlines : Tractogram object - Returns an object containing streamlines' data and header + tractogram : Tractogram object + Returns an object containing tractogram' data and header information. See `nibabel.Tractogram`. 
Notes @@ -412,13 +412,13 @@ def _apply_transform(trk_reader): yield pts, scals, props data = lambda: _apply_transform(trk_reader) - streamlines = LazyTractogram.create_from_data(data) + tractogram = LazyTractogram.create_from_data(data) # Overwrite scalars and properties if there is none if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: - streamlines.scalars = lambda: [] + tractogram.scalars = lambda: [] if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: - streamlines.properties = lambda: [] + tractogram.properties = lambda: [] # elif Field.NB_POINTS in trk_reader.header: # # 'count' field is provided, we can avoid creating list of numpy @@ -464,46 +464,46 @@ def _apply_transform(trk_reader): # streamlines.properties = [] else: - streamlines = Tractogram.create_from_generator(trk_reader) - #streamlines = Tractogram(*zip(*trk_reader)) - streamlines.apply_affine(affine) + tractogram = Tractogram.create_from_generator(trk_reader) + #tractogram = Tractogram(*zip(*trk_reader)) + tractogram.apply_affine(affine) # Overwrite scalars and properties if there is none if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: - streamlines.scalars = [] + tractogram.scalars = [] if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: - streamlines.properties = [] + tractogram.properties = [] # Set available common information about streamlines in the header - streamlines.header.to_world_space = affine + tractogram.header.to_world_space = affine # If 'count' field is 0, i.e. not provided, we don't set `nb_streamlines` if trk_reader.header[Field.NB_STREAMLINES] > 0: - streamlines.header.nb_streamlines = trk_reader.header[Field.NB_STREAMLINES] + tractogram.header.nb_streamlines = trk_reader.header[Field.NB_STREAMLINES] # Keep extra information about TRK format - streamlines.header.extra = trk_reader.header + tractogram.header.extra = trk_reader.header ## Perform some integrity checks - #if trk_reader.header[Field.VOXEL_ORDER] != streamlines.header.voxel_order: + #if trk_reader.header[Field.VOXEL_ORDER] != tractogram.header.voxel_order: # raise HeaderError("'voxel_order' does not match the affine.") - #if streamlines.header.voxel_sizes != trk_reader.header[Field.VOXEL_SIZES]: + #if tractogram.header.voxel_sizes != trk_reader.header[Field.VOXEL_SIZES]: # raise HeaderError("'voxel_sizes' does not match the affine.") - #if streamlines.header.nb_scalars_per_point != trk_reader.header[Field.NB_SCALARS_PER_POINT]: + #if tractogram.header.nb_scalars_per_point != trk_reader.header[Field.NB_SCALARS_PER_POINT]: # raise HeaderError("'nb_scalars_per_point' does not match.") - #if streamlines.header.nb_properties_per_streamline != trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: + #if tractogram.header.nb_properties_per_streamline != trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: # raise HeaderError("'nb_properties_per_streamline' does not match.") - return streamlines + return tractogram @staticmethod - def save(streamlines, fileobj, ref=None): - ''' Saves streamlines to a file-like object. + def save(tractogram, fileobj, ref=None): + ''' Saves tractogram to a file-like object. Parameters ---------- - streamlines : Tractogram object - Object containing streamlines' data and header information. + tractogram : Tractogram object + Object containing tractogram' data and header information. See 'nibabel.Tractogram'. fileobj : string or file-like object @@ -520,10 +520,10 @@ def save(streamlines, fileobj, ref=None): refers to the center of the voxel. 
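The 'voxelmm' convention referenced throughout this reader and writer can be made concrete with a small sketch. The values below are hypothetical, and the half-voxel shift is one plausible reading of the corner-versus-center note above, not code taken from the patch:

import numpy as np

voxel_sizes = np.array([2.0, 2.0, 2.0], dtype="f4")        # hypothetical header value
points_vox = np.array([[0, 0, 0], [1, 2, 3]], dtype="f4")  # voxel space, (0,0,0) = voxel center

# Scale to millimetres, then shift by half a voxel so that (0,0,0)
# ends up on the voxel corner, where TrackVis expects it.
points_voxelmm = points_vox * voxel_sizes + 0.5 * voxel_sizes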
'''
if ref is not None:
- streamlines.header.to_world_space = get_affine_from_reference(ref)
+ tractogram.header.to_world_space = get_affine_from_reference(ref)

- trk_writer = TrkWriter(fileobj, streamlines.header)
- trk_writer.write(streamlines)
+ trk_writer = TrkWriter(fileobj, tractogram.header)
+ trk_writer.write(tractogram)

@staticmethod
def pretty_print(fileobj):

From 2bd732af5822c0efbdfc8cbd808f35f788003a27 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Mon, 2 Nov 2015 17:09:19 -0500
Subject: [PATCH 015/135] Refactoring StreamingFile

---
 nibabel/streamlines/__init__.py | 88 ++++++-------
 nibabel/streamlines/base_format.py | 66 ++++------
 nibabel/streamlines/tests/test_streamlines.py | 66 +++++-----
 nibabel/streamlines/tests/test_trk.py | 117 ++++++++++--------
 nibabel/streamlines/trk.py | 92 +++++++++-----
 5 files changed, 224 insertions(+), 205 deletions(-)

diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py
index e23757013f..a58f72901c 100644
--- a/nibabel/streamlines/__init__.py
+++ b/nibabel/streamlines/__init__.py
@@ -1,4 +1,4 @@
-from .header import Field
+from .header import TractogramHeader
from .base_format import Tractogram, LazyTractogram
from nibabel.streamlines.trk import TrkFile

@@ -36,14 +36,14 @@ def detect_format(fileobj):
----------
fileobj : string or file-like object
If string, a filename; otherwise an open file-like object pointing
- to a streamlines file (and ready to read from the beginning of the
+ to a tractogram file (and ready to read from the beginning of the
header)

Returns
-------
- streamlines_file : StreamlinesFile object
- Object that can be used to manage a streamlines file.
- See 'nibabel.streamlines.StreamlinesFile'.
+ tractogram_file : ``TractogramFile`` class
+ The ``TractogramFile`` subclass that matches the format of
+ `fileobj`, or None if the format cannot be detected.
'''
for format in FORMATS.values():
try:
@@ -62,76 +62,68 @@
return None

-def load(fileobj, ref, lazy_load=False):
+def load(fileobj, lazy_load=False, ref=None):
''' Loads streamlines from a file-like object in voxel space.

Parameters
----------
fileobj : string or file-like object
- If string, a filename; otherwise an open file-like object
- pointing to a streamlines file (and ready to read from the beginning
- of the streamlines file's header).
-
- ref : filename | `Nifti1Image` object | 2D array (4,4)
- Reference space where streamlines will live in `fileobj`.
+ If string, a filename; otherwise an open file-like object
+ pointing to a streamlines file (and ready to read from the beginning
+ of the streamlines file's header).

lazy_load : boolean (optional)
Load streamlines in a lazy manner i.e. they will not be kept
in memory.

+ ref : filename | `Nifti1Image` object | 2D array (4,4) (optional)
+ Reference space where streamlines will live in `fileobj`.
+
Returns
-------
- obj : instance of `Streamlines`
- Returns an instance of a `Streamlines` class containing data and metadata
- of streamlines loaded from `fileobj`.
+ tractogram_file : ``TractogramFile``
+ Returns an instance of a `TractogramFile` class containing data and
+ metadata of the tractogram loaded from `fileobj`.
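Taken together with `save` and `save_tractogram` defined just below, the intended round trip looks roughly as follows. This is a sketch assuming the API exactly as declared in this patch; "bundle.trk" and "copy.trk" are placeholder filenames:

import nibabel as nib

trk = nib.streamlines.load("bundle.trk", lazy_load=False)  # a TractogramFile instance
tractogram = trk.tractogram                                # the underlying Tractogram

# Either save the loaded file object directly...
nib.streamlines.save(trk, "copy.trk")
# ...or build a new file from a raw Tractogram and let the
# filename extension pick the format class.
nib.streamlines.save_tractogram(tractogram, "copy.trk")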
''' - streamlines_file = detect_format(fileobj) + tractogram_file = detect_format(fileobj) - if streamlines_file is None: - raise TypeError("Unknown format for 'fileobj': {0}!".format(fileobj)) + if tractogram_file is None: + raise TypeError("Unknown format for 'fileobj': {}".format(fileobj)) - return streamlines_file.load(fileobj, ref, lazy_load=lazy_load) + return tractogram_file.load(fileobj, lazy_load=lazy_load, ref=ref) -def save(streamlines, filename, ref=None): - ''' Saves a `Streamlines` object to a file +def save(tractogram_file, filename): + ''' Saves a tractogram to a file. Parameters ---------- - streamlines : `Streamlines` object - Streamlines to be saved. + tractogram_file : ``TractogramFile`` object + Tractogram to be saved on disk. filename : str - Name of the file where the streamlines will be saved. The format will - be guessed from `filename`. - - ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) - Reference space where streamlines will live in `fileobj`. + Name of the file where the tractogram will be saved. The format will + be guessed from `filename`. ''' - streamlines_file = detect_format(filename) + tractogram_file.save(filename) - if streamlines_file is None: - raise TypeError("Unknown streamlines file format: '{0}'!".format(filename)) - streamlines_file.save(streamlines, filename, ref) - - -def convert(in_fileobj, out_filename, ref): - ''' Converts a streamlines file to another format. +def save_tractogram(tractogram, filename, **kwargs): + ''' Saves a tractogram to a file. Parameters ---------- - in_fileobj : string or file-like object - If string, a filename; otherwise an open file-like object pointing - to a streamlines file (and ready to read from the beginning of the - header). + tractogram : ``Tractogram`` object + Tractogram to be saved. - out_filename : str - Name of the file where the streamlines will be saved. The format will - be guessed from `out_filename`. - - ref : filename | `Nifti1Image` object | 2D array (4,4) - Reference space where streamlines live in `fileobj`. + filename : str + Name of the file where the tractogram will be saved. The format will + be guessed from `filename`. ''' - streamlines = load(in_fileobj, ref, lazy_load=True) - save(streamlines, out_filename) + tractogram_file_class = detect_format(filename) + + if tractogram_file_class is None: + raise TypeError("Unknown tractogram file format: '{}'".format(filename)) + + tractogram_file = tractogram_file_class(tractogram, **kwargs) + tractogram_file.save(filename) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 42e34d0a37..4077685dc0 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -339,7 +339,6 @@ def create_from_generator(cls, gen): return cls(streamlines, scalars, properties) - @property def header(self): return self._header @@ -586,12 +585,25 @@ def transform(self, affine): return super(LazyTractogram, self).transform(affine, lazy=True) +class abstractclassmethod(classmethod): + __isabstractmethod__ = True + + def __init__(self, callable): + callable.__isabstractmethod__ = True + super(abstractclassmethod, self).__init__(callable) + + class TractogramFile(object): - ''' Convenience class to encapsulate streamlines file format. ''' + ''' Convenience class to encapsulate tractogram file format. 
''' __metaclass__ = ABCMeta - def __init__(self, tractogram): - self.tractogram = tractogram + def __init__(self, tractogram, header=None): + self._tractogram = tractogram + self._header = TractogramHeader() if header is None else header + + @property + def tractogram(self): + return self._tractogram @property def streamlines(self): @@ -607,7 +619,7 @@ def properties(self): @property def header(self): - return self.tractogram.header + return self._header @classmethod def get_magic_number(cls): @@ -642,8 +654,8 @@ def is_correct_format(cls, fileobj): ''' raise NotImplementedError() - @staticmethod - def load(fileobj, ref, lazy_load=True): + @abstractclassmethod + def load(cls, fileobj, lazy_load=True): ''' Loads streamlines from a file-like object. Parameters @@ -653,54 +665,26 @@ def load(fileobj, ref, lazy_load=True): pointing to a streamlines file (and ready to read from the beginning of the header). - ref : filename | `Nifti1Image` object | 2D array (4,4) - Reference space where streamlines live in `fileobj`. - - lazy_load : boolean + lazy_load : boolean (optional) Load streamlines in a lazy manner i.e. they will not be kept in memory. For postprocessing speed, turn off this option. Returns ------- - streamlines : Tractogram object - Returns an object containing streamlines' data and header - information. See 'nibabel.Tractogram'. + tractogram_file : ``TractogramFile`` object + Returns an object containing tractogram data and header + information. ''' raise NotImplementedError() - @staticmethod - def save(streamlines, fileobj, ref=None): + @abstractmethod + def save(self, fileobj): ''' Saves streamlines to a file-like object. Parameters ---------- - streamlines : Tractogram object - Object containing streamlines' data and header information. - See 'nibabel.Tractogram'. - fileobj : string or file-like object If string, a filename; otherwise an open file-like object opened and ready to write. - - ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) - Reference space where streamlines will live in `fileobj`. - ''' - raise NotImplementedError() - - @staticmethod - def pretty_print(streamlines): - ''' Gets a formatted string of the header of a streamlines file format. - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to a streamlines file (and ready to read from the - beginning of the header). - - Returns - ------- - info : string - Header information relevant to the streamlines file format. ''' raise NotImplementedError() diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 5cd30c65f6..e78be90429 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -12,7 +12,7 @@ from nibabel.testing import assert_arrays_equal, isiterable from nose.tools import assert_equal, assert_raises, assert_true, assert_false -from ..base_format import Tractogram, LazyTractogram +from ..base_format import Tractogram, LazyTractogram, TractogramFile from ..base_format import HeaderError, UsageWarning from ..header import Field from .. 
import trk @@ -134,28 +134,32 @@ def setUp(self): def test_load_empty_file(self): for empty_filename in self.empty_filenames: - tractogram = nib.streamlines.load(empty_filename, - ref=self.to_world_space, - lazy_load=False) - assert_true(type(tractogram), Tractogram) - check_tractogram(tractogram, 0, [], [], []) + tractogram_file = nib.streamlines.load(empty_filename, + lazy_load=False, + ref=self.to_world_space) + assert_true(isinstance(tractogram_file, TractogramFile)) + assert_true(type(tractogram_file.tractogram), Tractogram) + check_tractogram(tractogram_file.tractogram, 0, [], [], []) def test_load_simple_file(self): for simple_filename in self.simple_filenames: - tractogram = nib.streamlines.load(simple_filename, - ref=self.to_world_space, - lazy_load=False) - assert_true(type(tractogram), Tractogram) - check_tractogram(tractogram, self.nb_streamlines, - self.streamlines, [], []) + tractogram_file = nib.streamlines.load(simple_filename, + lazy_load=False, + ref=self.to_world_space) + assert_true(isinstance(tractogram_file, TractogramFile)) + assert_true(type(tractogram_file.tractogram), Tractogram) + check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + self.streamlines, [], []) # Test lazy_load - tractogram = nib.streamlines.load(simple_filename, - ref=self.to_world_space, - lazy_load=True) - assert_true(type(tractogram), LazyTractogram) - check_tractogram(tractogram, self.nb_streamlines, - self.streamlines, [], []) + tractogram_file = nib.streamlines.load(simple_filename, + lazy_load=True, + ref=self.to_world_space) + + assert_true(isinstance(tractogram_file, TractogramFile)) + assert_true(type(tractogram_file.tractogram), LazyTractogram) + check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + self.streamlines, [], []) def test_load_complex_file(self): for complex_filename in self.complex_filenames: @@ -169,20 +173,22 @@ def test_load_complex_file(self): if file_format.can_save_properties(): properties = self.mean_curvature_torsion - tractogram = nib.streamlines.load(complex_filename, - ref=self.to_world_space, - lazy_load=False) - assert_true(type(tractogram), Tractogram) - check_tractogram(tractogram, self.nb_streamlines, - self.streamlines, scalars, properties) + tractogram_file = nib.streamlines.load(complex_filename, + lazy_load=False, + ref=self.to_world_space) + assert_true(isinstance(tractogram_file, TractogramFile)) + assert_true(type(tractogram_file.tractogram), Tractogram) + check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + self.streamlines, scalars, properties) # Test lazy_load - tractogram = nib.streamlines.load(complex_filename, - ref=self.to_world_space, - lazy_load=True) - assert_true(type(tractogram), LazyTractogram) - check_tractogram(tractogram, self.nb_streamlines, - self.streamlines, scalars, properties) + tractogram_file = nib.streamlines.load(complex_filename, + lazy_load=True, + ref=self.to_world_space) + assert_true(isinstance(tractogram_file, TractogramFile)) + assert_true(type(tractogram_file.tractogram), LazyTractogram) + check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + self.streamlines, scalars, properties) def test_save_simple_file(self): tractogram = Tractogram(self.streamlines) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 40720027fc..e1584731db 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -12,7 +12,7 @@ from ..base_format import Tractogram, LazyTractogram from ..base_format import 
DataError, HeaderError, HeaderWarning, UsageWarning -from .. import trk +#from .. import trk from ..trk import TrkFile DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') @@ -57,30 +57,31 @@ def setUp(self): self.nb_streamlines = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) + self.affine = np.eye(4) def test_load_empty_file(self): trk = TrkFile.load(self.empty_trk_filename, ref=None, lazy_load=False) - check_tractogram(trk, 0, [], [], []) + check_tractogram(trk.tractogram, 0, [], [], []) trk = TrkFile.load(self.empty_trk_filename, ref=None, lazy_load=True) # Suppress warning about loading a TRK file in lazy mode with count=0. with suppress_warnings(): - check_tractogram(trk, 0, [], [], []) + check_tractogram(trk.tractogram, 0, [], [], []) def test_load_simple_file(self): trk = TrkFile.load(self.simple_trk_filename, ref=None, lazy_load=False) - check_tractogram(trk, self.nb_streamlines, self.streamlines, [], []) + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) trk = TrkFile.load(self.simple_trk_filename, ref=None, lazy_load=True) - check_tractogram(trk, self.nb_streamlines, self.streamlines, [], []) + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) def test_load_complex_file(self): trk = TrkFile.load(self.complex_trk_filename, ref=None, lazy_load=False) - check_tractogram(trk, self.nb_streamlines, + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, self.colors, self.mean_curvature_torsion) trk = TrkFile.load(self.complex_trk_filename, ref=None, lazy_load=True) - check_tractogram(trk, self.nb_streamlines, + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, self.colors, self.mean_curvature_torsion) def test_load_file_with_wrong_information(self): @@ -89,12 +90,12 @@ def test_load_file_with_wrong_information(self): # Simulate a TRK file where `count` was not provided. 
count = np.array(0, dtype="int32").tostring() new_trk_file = trk_file[:1000-12] + count + trk_file[1000-8:] - tractogram = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) - check_tractogram(tractogram, self.nb_streamlines, self.streamlines, [], []) + trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) - tractogram = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) + trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) with clear_and_catch_warnings(record=True, modules=[base_format]) as w: - check_tractogram(tractogram, self.nb_streamlines, self.streamlines, [], []) + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) assert_equal(len(w), 1) assert_true(issubclass(w[0].category, UsageWarning)) @@ -121,15 +122,16 @@ def test_write_simple_file(self): tractogram = Tractogram(self.streamlines) trk_file = BytesIO() - TrkFile.save(tractogram, trk_file) + trk = TrkFile(tractogram, ref=self.affine) + trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_tractogram = TrkFile.load(trk_file) - check_tractogram(loaded_tractogram, self.nb_streamlines, - self.streamlines, [], []) + loaded_trk = TrkFile.load(trk_file) + check_tractogram(loaded_trk.tractogram, self.nb_streamlines, + self.streamlines, [], []) - loaded_tractogram_orig = TrkFile.load(self.simple_trk_filename) - assert_tractogram_equal(loaded_tractogram, loaded_tractogram_orig) + loaded_trk_orig = TrkFile.load(self.simple_trk_filename) + assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(open(self.simple_trk_filename, 'rb').read(), trk_file.read()) @@ -139,38 +141,40 @@ def test_write_complex_file(self): tractogram = Tractogram(self.streamlines, scalars=self.colors) trk_file = BytesIO() - TrkFile.save(tractogram, trk_file) + trk = TrkFile(tractogram, ref=self.affine) + trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_tractogram = TrkFile.load(trk_file, ref=None, lazy_load=False) - - check_tractogram(loaded_tractogram, self.nb_streamlines, - self.streamlines, self.colors, []) + loaded_trk = TrkFile.load(trk_file, lazy_load=False) + check_tractogram(loaded_trk.tractogram, self.nb_streamlines, + self.streamlines, self.colors, []) # With properties tractogram = Tractogram(self.streamlines, properties=self.mean_curvature_torsion) trk_file = BytesIO() - TrkFile.save(tractogram, trk_file) + trk = TrkFile(tractogram, ref=self.affine) + trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_tractogram = TrkFile.load(trk_file, ref=None, lazy_load=False) - check_tractogram(loaded_tractogram, self.nb_streamlines, - self.streamlines, [], self.mean_curvature_torsion) + loaded_trk = TrkFile.load(trk_file, lazy_load=False) + check_tractogram(loaded_trk.tractogram, self.nb_streamlines, + self.streamlines, [], self.mean_curvature_torsion) # With scalars and properties tractogram = Tractogram(self.streamlines, scalars=self.colors, properties=self.mean_curvature_torsion) trk_file = BytesIO() - TrkFile.save(tractogram, trk_file) + trk = TrkFile(tractogram, ref=self.affine) + trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_tractogram = TrkFile.load(trk_file, ref=None, lazy_load=False) - check_tractogram(loaded_tractogram, self.nb_streamlines, - self.streamlines, self.colors, self.mean_curvature_torsion) + loaded_trk = TrkFile.load(trk_file, lazy_load=False) + check_tractogram(loaded_trk.tractogram, self.nb_streamlines, + self.streamlines, 
self.colors, self.mean_curvature_torsion) - loaded_tractogram_orig = TrkFile.load(self.complex_trk_filename) - assert_tractogram_equal(loaded_tractogram, loaded_tractogram_orig) + loaded_trk_orig = TrkFile.load(self.complex_trk_filename) + assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(open(self.complex_trk_filename, 'rb').read(), trk_file.read()) @@ -182,14 +186,16 @@ def test_write_erroneous_file(self): [(0, 0, 1)]] tractogram = Tractogram(self.streamlines, scalars) - assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) + trk = TrkFile(tractogram, ref=self.affine) + assert_raises(DataError, trk.save, BytesIO()) # No scalars for every streamlines scalars = [[(1, 0, 0)]*1, [(0, 1, 0)]*2] tractogram = Tractogram(self.streamlines, scalars) - assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) + trk = TrkFile(tractogram, ref=self.affine) + assert_raises(DataError, trk.save, BytesIO()) # # Unit test moved to test_base_format.py # # Inconsistent number of scalars between points @@ -215,29 +221,32 @@ def test_write_erroneous_file(self): np.array([2.11], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] tractogram = Tractogram(self.streamlines, properties=properties) - assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) + trk = TrkFile(tractogram, ref=self.affine) + assert_raises(DataError, trk.save, BytesIO()) # Unit test moved to test_base_format.py # No properties for every streamlines properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11, 2.22], dtype="f4")] tractogram = Tractogram(self.streamlines, properties=properties) - assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) - - def test_write_file_lazy_tractogram(self): - streamlines = lambda: (point for point in self.streamlines) - scalars = lambda: (scalar for scalar in self.colors) - properties = lambda: (prop for prop in self.mean_curvature_torsion) - - tractogram = LazyTractogram(streamlines, scalars, properties) - # No need to manually set `nb_streamlines` in the header since we count - # them as writing. - #tractogram.header.nb_streamlines = self.nb_streamlines - - trk_file = BytesIO() - TrkFile.save(tractogram, trk_file) - trk_file.seek(0, os.SEEK_SET) - - trk = TrkFile.load(trk_file, ref=None, lazy_load=False) - check_tractogram(trk, self.nb_streamlines, - self.streamlines, self.colors, self.mean_curvature_torsion) + trk = TrkFile(tractogram, ref=self.affine) + assert_raises(DataError, trk.save, BytesIO()) + + # def test_write_file_lazy_tractogram(self): + # streamlines = lambda: (point for point in self.streamlines) + # scalars = lambda: (scalar for scalar in self.colors) + # properties = lambda: (prop for prop in self.mean_curvature_torsion) + + # tractogram = LazyTractogram(streamlines, scalars, properties) + # # No need to manually set `nb_streamlines` in the header since we count + # # them as writing. 
+    #     #tractogram.header.nb_streamlines = self.nb_streamlines
+
+    #     trk_file = BytesIO()
+    #     trk = TrkFile(tractogram, ref=self.affine)
+    #     trk.save(trk_file)
+    #     trk_file.seek(0, os.SEEK_SET)
+
+    #     trk = TrkFile.load(trk_file, ref=None, lazy_load=False)
+    #     check_tractogram(trk.tractogram, self.nb_streamlines,
+    #                      self.streamlines, self.colors, self.mean_curvature_torsion)

diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index ae913ad67f..1042e5f9af 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -8,11 +8,11 @@
 import warnings
 
 import numpy as np
+import nibabel as nib
 
 from nibabel.openers import Opener
 from nibabel.volumeutils import (native_code, swapped_code)
 
-from nibabel.streamlines.base_format import CompactList
 from nibabel.streamlines.base_format import TractogramFile
 from nibabel.streamlines.base_format import DataError, HeaderError, HeaderWarning
 from nibabel.streamlines.base_format import Tractogram, LazyTractogram
@@ -319,6 +319,27 @@ class TrkFile(TractogramFile):
     MAGIC_NUMBER = b"TRACK"
     HEADER_SIZE = 1000
 
+    def __init__(self, tractogram, ref, header=None):
+        """
+        Parameters
+        ----------
+        tractogram : ``Tractogram`` object
+            Tractogram that will be contained in this ``TrkFile``.
+
+        ref : filename | `Nifti1Image` object | 2D array (4,4)
+            Reference space where streamlines live in.
+
+        header : ``TractogramHeader`` object
+            Metadata associated to this tractogram file.
+
+        Notes
+        -----
+        Streamlines of the tractogram are assumed to be in *RAS+* and *mm* space
+        where coordinate (0,0,0) refers to the center of the voxel.
+        """
+        super(TrkFile, self).__init__(tractogram, header)
+        self._affine = get_affine_from_reference(ref)
+
     @classmethod
     def get_magic_number(cls):
         ''' Return TRK's magic number. '''
@@ -357,8 +378,8 @@ def is_correct_format(cls, fileobj):
 
         return False
 
-    @staticmethod
-    def load(fileobj, ref=None, lazy_load=False):
+    @classmethod
+    def load(cls, fileobj, lazy_load=False, ref=None):
         ''' Loads streamlines from a file-like object.
 
         Parameters
         ----------
@@ -368,30 +389,44 @@
         fileobj : string or file-like object
             If string, a filename; otherwise an open file-like object
             pointing to TRK file (and ready to read from the beginning
            of the TRK header).
 
-        ref : filename | `Nifti1Image` object | 2D array (4,4) | None
-            Reference space where streamlines live in `fileobj`.
-
         lazy_load : boolean (optional)
             Load streamlines in a lazy manner i.e. they will not be kept
             in memory.
 
+        ref : filename | `Nifti1Image` object | 2D array (4,4) (optional)
+            Reference space where streamlines live in `fileobj`.
+
         Returns
         -------
-        tractogram : Tractogram object
-            Returns an object containing tractogram' data and header
-            information. See `nibabel.Tractogram`.
+        trk_file : ``TrkFile`` object
+            Returns an object containing tractogram data and header
+            information.
 
         Notes
         -----
-        Tractogram are assumed to be in voxel space where coordinate (0,0,0)
-        refers to the center of the voxel.
+        Streamlines of the returned tractogram are assumed to be in RASmm
+        space where coordinate (0,0,0) refers to the center of the voxel.
         '''
         trk_reader = TrkReader(fileobj)
 
         # TRK's streamlines are in 'voxelmm' space, we send them to rasmm.
-        affine = trk_reader.header[Field.to_world_space]
+        # First send them to voxel space.
+        affine = np.eye(4)
         affine[range(3), range(3)] /= trk_reader.header[Field.VOXEL_SIZES]
 
+        # If voxel order implied from the affine does not match the voxel
+        # order saved in the TRK header, change the orientation.
+        header_ornt = trk_reader.header[Field.VOXEL_ORDER]
+        affine_ornt = "".join(nib.orientations.aff2axcodes(trk_reader.header[Field.to_world_space]))
+        header_ornt = nib.orientations.axcodes2ornt(header_ornt)
+        affine_ornt = nib.orientations.axcodes2ornt(affine_ornt)
+        ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt)
+        M = nib.orientations.inv_ornt_aff(ornt, trk_reader.header[Field.DIMENSIONS])
+        affine = np.dot(M, affine)
+
+        # Apply the affine going from voxel space to rasmm.
+        affine = np.dot(trk_reader.header[Field.to_world_space], affine)
+
         # TrackVis considers coordinate (0,0,0) to be the corner of the
         # voxel whereas streamlines returned assume (0,0,0) to be the
         # center of the voxel. Thus, streamlines are shifted of half
@@ -494,36 +529,29 @@ def _apply_transform(trk_reader):
         #if tractogram.header.nb_properties_per_streamline != trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]:
         #    raise HeaderError("'nb_properties_per_streamline' does not match.")
 
-        return tractogram
+        return cls(tractogram, ref=affine, header=trk_reader.header)
 
-    @staticmethod
-    def save(tractogram, fileobj, ref=None):
-        ''' Saves tractogram to a file-like object.
+    def save(self, fileobj):
+        ''' Saves tractogram to a file-like object using TRK format.
 
         Parameters
         ----------
-        tractogram : Tractogram object
-            Object containing tractogram' data and header information.
-            See 'nibabel.Tractogram'.
-
         fileobj : string or file-like object
             If string, a filename; otherwise an open file-like object
             pointing to TRK file (and ready to read from the beginning
             of the TRK header data).
-
-        ref : filename | `Nifti1Image` object | 2D array (4,4) (optional)
-            Reference space where streamlines will live in `fileobj`.
-
-        Notes
-        -----
-        Tractogram are assumed to be in voxel space where coordinate (0,0,0)
-        refers to the center of the voxel.
         '''
-        if ref is not None:
-            tractogram.header.to_world_space = get_affine_from_reference(ref)
+        # Update header using the tractogram.
+        self.header.nb_scalars_per_point = 0
+        if self.tractogram.scalars.shape is not None:
+            self.header.nb_scalars_per_point = len(self.tractogram.scalars[0])
+
+        self.header.nb_properties_per_streamline = 0
+        if self.tractogram.properties.shape is not None:
+            self.header.nb_properties_per_streamline = len(self.tractogram.properties[0])
 
-        trk_writer = TrkWriter(fileobj, tractogram.header)
-        trk_writer.write(tractogram)
+        trk_writer = TrkWriter(fileobj, self.header)
+        trk_writer.write(self.tractogram)
 
     @staticmethod
     def pretty_print(fileobj):
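The hunk above packs the whole voxelmm-to-world conversion into three small steps: undo the voxel-size scaling, reorient to match the header's voxel order, then apply the voxel-to-world affine. The snippet below is a self-contained sketch of those same steps, not part of the patch series; the header values are invented for illustration, while the ``nibabel.orientations`` helpers are the same public functions the patch calls.

import numpy as np
import nibabel as nib

# Made-up TRK header fields, for illustration only.
voxel_sizes = np.array([2., 2., 2.])
dimensions = np.array([90, 108, 90])
voxel_order = "LPS"                           # Field.VOXEL_ORDER
to_world_space = np.diag([-2., -2., 2., 1.])  # Field.to_world_space

# Step 1: voxelmm -> voxel, by dividing out the voxel sizes.
affine = np.eye(4)
affine[range(3), range(3)] /= voxel_sizes

# Step 2: reorient if the header's voxel order disagrees with the
# orientation implied by the voxel-to-world affine.
header_ornt = nib.orientations.axcodes2ornt(voxel_order)
affine_ornt = nib.orientations.axcodes2ornt(
    nib.orientations.aff2axcodes(to_world_space))
ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt)
M = nib.orientations.inv_ornt_aff(ornt, dimensions)
affine = np.dot(M, affine)

# Step 3: voxel -> world (RAS+ and mm).
affine = np.dot(to_world_space, affine)

# A point at voxelmm (10, 10, 10) ends up at:
print(np.dot(affine, [10., 10., 10., 1.])[:3])

When the two orientations agree, as here, the middle step collapses to the identity, which is the common case for TRK files written with a consistent header.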
From 8dd08cc24a53c8553468a36cc88437de1ef51e4c Mon Sep 17 00:00:00 2001
From: Eleftherios Garyfallidis
Date: Mon, 2 Nov 2015 17:09:54 -0500
Subject: [PATCH 016/135] NF: Tractogram is now more specific and supports
 multiple keys for scalars and properties which are now called
 data_per_streamline and data_per_point respectively

---
 nibabel/streamlines/base_format.py            | 290 ++++++------------
 nibabel/streamlines/tests/test_base_format.py |  15 +
 2 files changed, 110 insertions(+), 195 deletions(-)

diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py
index 4077685dc0..bf9cea59a7 100644
--- a/nibabel/streamlines/base_format.py
+++ b/nibabel/streamlines/base_format.py
@@ -194,10 +194,10 @@ def __repr__(self):
 
 class TractogramItem(object):
-    ''' Class containing information about one streamline.
+    """ Class containing information about one streamline.
 
-    ``TractogramItem`` objects have three main properties: `streamline`, `scalars`
-    and ``properties``.
+    ``TractogramItem`` objects have three main properties: `streamline`,
+    `data_for_streamline`, and `data_for_points`.
 
     Parameters
     ----------
     streamline : ndarray of shape (N, 3)
         Points of this streamline represented as an ndarray of shape
         (N, 3) where N is the number of points.
 
-    scalars : ndarray of shape (N, M)
-        Scalars associated with each point of this streamline and represented
-        as an ndarray of shape (N, M) where N is the number of points and
-        M is the number of scalars (excluding the three coordinates).
+    data_for_streamline : dict
 
-    properties : ndarray of shape (P,)
-        Properties associated with this streamline and represented as an
-        ndarray of shape (P,) where P is the number of properties.
-    '''
-    def __init__(self, streamline, scalars=None, properties=None):
-        #if scalars is not None and len(streamline) != len(scalars):
-        #    raise ValueError("First dimension of streamline and scalars must match.")
+    data_for_points : dict
+    """
+    def __init__(self, streamline,
+                 data_for_streamline, data_for_points):
         self.streamline = np.asarray(streamline)
-        self.scalars = np.asarray([] if scalars is None else scalars)
-        self.properties = np.asarray([] if properties is None else properties)
+        self.data_for_streamline = data_for_streamline
+        self.data_for_points = data_for_points
 
     def __iter__(self):
         return iter(self.streamline)
 
@@ -230,186 +224,77 @@ def __len__(self):
 
 class Tractogram(object):
-    ''' Class containing information about streamlines.
+    """ Class containing information about streamlines.
 
-    Tractogram objects have three main properties: ``streamlines``, ``scalars``
-    and ``properties``. Tractogram objects can be iterate over producing
-    tuple of ``streamlines``, ``scalars`` and ``properties`` for each streamline.
+    Tractogram objects have three main properties: ``streamlines``,
+    ``data_per_streamline`` and ``data_per_point``.
 
     Parameters
     ----------
     streamlines : list of ndarray of shape (Nt, 3)
-        Sequence of T streamlines. One streamline is an ndarray of shape (Nt, 3)
-        where Nt is the number of points of streamline t.
+        Sequence of T streamlines. One streamline is an ndarray of shape
+        (Nt, 3) where Nt is the number of points of streamline t.
+
+    data_per_streamline : dictionary of list of ndarray of shape (P,)
+        Sequence of T ndarrays of shape (P,) where T is the number of
+        streamlines defined by ``streamlines``, P is the number of properties
+        associated to each streamline.
 
-    scalars : list of ndarray of shape (Nt, M)
+    data_per_point : dictionary of list of ndarray of shape (Nt, M)
         Sequence of T ndarrays of shape (Nt, M) where T is the number of
         streamlines defined by ``streamlines``, Nt is the number of points
         for a particular streamline t and M is the number of scalars
         associated to each point (excluding the three coordinates).
 
-    properties : list of ndarray of shape (P,)
-        Sequence of T ndarrays of shape (P,) where T is the number of
-        streamlines defined by ``streamlines``, P is the number of properties
-        associated to each streamline.
-    '''
-    def __init__(self, streamlines=None, scalars=None, properties=None):
-        self._header = TractogramHeader()
-        self.streamlines = streamlines
-        self.scalars = scalars
-        self.properties = properties
+    """
+    def __init__(self, streamlines=None,
+                 data_per_streamline=None,
+                 data_per_point=None):
 
-    @classmethod
-    def create_from_generator(cls, gen):
-        BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3.
- - streamlines = CompactList() - scalars = CompactList() - properties = np.array([]) - - gen = iter(gen) - try: - first_element = next(gen) - gen = itertools.chain([first_element], gen) - except StopIteration: - return cls(streamlines, scalars, properties) - - # Allocated some buffer memory. - pts = np.asarray(first_element[0]) - scals = np.asarray(first_element[1]) - props = np.asarray(first_element[2]) - - scals_shape = scals.shape - props_shape = props.shape - - streamlines._data = np.empty((BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype) - scalars._data = np.empty((BUFFER_SIZE, scals.shape[1]), dtype=scals.dtype) - properties = np.empty((BUFFER_SIZE, props.shape[0]), dtype=props.dtype) - - offset = 0 - for i, (pts, scals, props) in enumerate(gen): - pts = np.asarray(pts) - scals = np.asarray(scals) - props = np.asarray(props) - - if scals.shape[1] != scals_shape[1]: - raise ValueError("Number of scalars differs from one" - " point or streamline to another") - - if props.shape != props_shape: - raise ValueError("Number of properties differs from one" - " streamline to another") - - end = offset + len(pts) - if end >= len(streamlines._data): - # Resize is needed (at least `len(pts)` items will be added). - streamlines._data.resize((len(streamlines._data) + len(pts)+BUFFER_SIZE, pts.shape[1])) - scalars._data.resize((len(scalars._data) + len(scals)+BUFFER_SIZE, scals.shape[1])) - - streamlines._offsets.append(offset) - streamlines._lengths.append(len(pts)) - streamlines._data[offset:offset+len(pts)] = pts - scalars._data[offset:offset+len(scals)] = scals - - offset += len(pts) - - if i >= len(properties): - properties.resize((len(properties) + BUFFER_SIZE, props.shape[0])) - - properties[i] = props - - # Clear unused memory. - streamlines._data.resize((offset, pts.shape[1])) - - if scals_shape[1] == 0: - # Because resizing an empty ndarray creates memory! - scalars._data = np.empty((offset, scals.shape[1])) + self.streamlines = streamlines + if data_per_streamline is None: + self.data_per_streamline = {} else: - scalars._data.resize((offset, scals.shape[1])) - - # Share offsets and lengths between streamlines and scalars. - scalars._offsets = streamlines._offsets - scalars._lengths = streamlines._lengths - - if props_shape[0] == 0: - # Because resizing an empty ndarray creates memory! 
- properties = np.empty((i+1, props.shape[0])) + self.data_per_streamline = data_per_streamline + if data_per_point is None: + self.data_per_point = {} else: - properties.resize((i+1, props.shape[0])) - - return cls(streamlines, scalars, properties) - - @property - def header(self): - return self._header - - @property - def streamlines(self): - return self._streamlines - - @streamlines.setter - def streamlines(self, value): - self._streamlines = value - if not isinstance(value, CompactList): - self._streamlines = CompactList(value) - - self.header.nb_streamlines = len(self.streamlines) - - @property - def scalars(self): - return self._scalars - - @scalars.setter - def scalars(self, value): - self._scalars = value - if not isinstance(value, CompactList): - self._scalars = CompactList(value) - - self.header.nb_scalars_per_point = 0 - if len(self.scalars) > 0 and len(self.scalars[0]) > 0: - self.header.nb_scalars_per_point = len(self.scalars[0][0]) - - @property - def properties(self): - return self._properties - - @properties.setter - def properties(self, value): - self._properties = np.asarray(value) - if value is None: - self._properties = np.empty((len(self), 0), dtype=np.float32) - - self.header.nb_properties_per_streamline = 0 - if len(self.properties) > 0: - self.header.nb_properties_per_streamline = len(self.properties[0]) + self.data_per_point = data_per_point def __iter__(self): - for data in zip_longest(self.streamlines, self.scalars, self.properties, fillvalue=None): - yield TractogramItem(*data) + for i in range(len(self.streamlines)): + yield self[i] def __getitem__(self, idx): pts = self.streamlines[idx] - scalars = [] - if len(self.scalars) > 0: - scalars = self.scalars[idx] - properties = [] - if len(self.properties) > 0: - properties = self.properties[idx] + new_data_per_streamline = {} + for key in self.data_per_streamline: + new_data_per_streamline[key] = self.data_per_streamline[key][idx] + + new_data_per_point = {} + for key in self.data_per_point: + new_data_per_point[key] = self.data_per_point[key][idx] if type(idx) is slice: - return Tractogram(pts, scalars, properties) + return Tractogram(pts, new_data_per_streamline, new_data_per_point) - return TractogramItem(pts, scalars, properties) - - def __len__(self): - return len(self.streamlines) + return TractogramItem(pts, new_data_per_streamline, new_data_per_point) def copy(self): """ Returns a copy of this `Tractogram` object. """ - streamlines = Tractogram(self.streamlines.copy(), self.scalars.copy(), self.properties.copy()) - streamlines._header = self.header.copy() - return streamlines + + new_data_per_streamline = {} + for key in self.data_per_streamline: + new_data_per_streamline[key] = self.data_per_streamline[key].copy() + + new_data_per_point = {} + for key in self.data_per_point: + new_data_per_point[key] = self.data_per_point[key].copy() + + tractogram = Tractogram(self.streamlines.copy(), + new_data_per_streamline, + new_data_per_point) + return tractogram def apply_affine(self, affine): """ Applies an affine transformation on the points of each streamline. @@ -585,25 +470,12 @@ def transform(self, affine): return super(LazyTractogram, self).transform(affine, lazy=True) -class abstractclassmethod(classmethod): - __isabstractmethod__ = True - - def __init__(self, callable): - callable.__isabstractmethod__ = True - super(abstractclassmethod, self).__init__(callable) - - class TractogramFile(object): - ''' Convenience class to encapsulate tractogram file format. 
''' + ''' Convenience class to encapsulate streamlines file format. ''' __metaclass__ = ABCMeta - def __init__(self, tractogram, header=None): - self._tractogram = tractogram - self._header = TractogramHeader() if header is None else header - - @property - def tractogram(self): - return self._tractogram + def __init__(self, tractogram): + self.tractogram = tractogram @property def streamlines(self): @@ -619,7 +491,7 @@ def properties(self): @property def header(self): - return self._header + return self.tractogram.header @classmethod def get_magic_number(cls): @@ -654,8 +526,8 @@ def is_correct_format(cls, fileobj): ''' raise NotImplementedError() - @abstractclassmethod - def load(cls, fileobj, lazy_load=True): + @staticmethod + def load(fileobj, ref, lazy_load=True): ''' Loads streamlines from a file-like object. Parameters @@ -665,26 +537,54 @@ def load(cls, fileobj, lazy_load=True): pointing to a streamlines file (and ready to read from the beginning of the header). - lazy_load : boolean (optional) + ref : filename | `Nifti1Image` object | 2D array (4,4) + Reference space where streamlines live in `fileobj`. + + lazy_load : boolean Load streamlines in a lazy manner i.e. they will not be kept in memory. For postprocessing speed, turn off this option. Returns ------- - tractogram_file : ``TractogramFile`` object - Returns an object containing tractogram data and header - information. + streamlines : Tractogram object + Returns an object containing streamlines' data and header + information. See 'nibabel.Tractogram'. ''' raise NotImplementedError() - @abstractmethod - def save(self, fileobj): + @staticmethod + def save(streamlines, fileobj, ref=None): ''' Saves streamlines to a file-like object. Parameters ---------- + streamlines : Tractogram object + Object containing streamlines' data and header information. + See 'nibabel.Tractogram'. + fileobj : string or file-like object If string, a filename; otherwise an open file-like object opened and ready to write. + + ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) + Reference space where streamlines will live in `fileobj`. + ''' + raise NotImplementedError() + + @staticmethod + def pretty_print(streamlines): + ''' Gets a formatted string of the header of a streamlines file format. + + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header). + + Returns + ------- + info : string + Header information relevant to the streamlines file format. 
''' raise NotImplementedError() diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index 99d8c8ea90..ad22274012 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -251,10 +251,25 @@ def setUp(self): np.array([2.11, 2.22], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] + self.mean_curvature = np.array([1.11, 2.11, 3.11], dtype="f4") + self.mean_color = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype="f4") + + self.nb_tractogram = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) + def test_tractogram_creation_with_dix(self): + + tractogram = Tractogram( + streamlines=self.streamlines, + data_per_streamline={'mean_curvature': self.mean_curvature, + 'mean_color': self.mean_color}, + data_per_point={'colors': self.colors}) + + tractogram[:2] + + def test_tractogram_creation_from_arrays(self): # Empty tractogram = Tractogram() From c7b7bab86db823110324f45f889d88026cd27d7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 2 Nov 2015 18:03:37 -0500 Subject: [PATCH 017/135] Fixed some unit tests for Tractogram and TractogramItem --- nibabel/streamlines/base_format.py | 47 +++- nibabel/streamlines/tests/test_base_format.py | 247 +++++++----------- 2 files changed, 133 insertions(+), 161 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index bf9cea59a7..52563a5a97 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -209,9 +209,7 @@ class TractogramItem(object): data_for_points : dict """ - def __init__(self, streamline, - data_for_streamline, data_for_points): - + def __init__(self, streamline, data_for_streamline, data_for_points): self.streamline = np.asarray(streamline) self.data_for_streamline = data_for_streamline self.data_for_points = data_for_points @@ -251,15 +249,43 @@ def __init__(self, streamlines=None, data_per_point=None): self.streamlines = streamlines - if data_per_streamline is None: - self.data_per_streamline = {} - else: + + self.data_per_streamline = {} + if data_per_streamline is not None: self.data_per_streamline = data_per_streamline - if data_per_point is None: - self.data_per_point = {} - else: + + self.data_per_point = {} + if data_per_point is not None: self.data_per_point = data_per_point + @property + def streamlines(self): + return self._streamlines + + @streamlines.setter + def streamlines(self, value): + self._streamlines = CompactList(value) + + @property + def data_per_streamline(self): + return self._data_per_streamline + + @data_per_streamline.setter + def data_per_streamline(self, value): + self._data_per_streamline = {} + for k, v in value.items(): + self._data_per_streamline[k] = np.asarray(v) + + @property + def data_per_point(self): + return self._data_per_point + + @data_per_point.setter + def data_per_point(self, value): + self._data_per_point = {} + for k, v in value.items(): + self._data_per_point[k] = CompactList(v) + def __iter__(self): for i in range(len(self.streamlines)): yield self[i] @@ -280,6 +306,9 @@ def __getitem__(self, idx): return TractogramItem(pts, new_data_per_streamline, new_data_per_point) + def __len__(self): + return len(self.streamlines) + def copy(self): """ Returns a copy of this `Tractogram` object. 
""" diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index ad22274012..5ae8c0d5d8 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -3,7 +3,7 @@ import numpy as np import warnings -from nibabel.testing import assert_arrays_equal +from nibabel.testing import assert_arrays_equal, isiterable from nibabel.testing import suppress_warnings, clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true from numpy.testing import assert_array_equal, assert_array_almost_equal @@ -209,36 +209,31 @@ class TestTractogramItem(unittest.TestCase): def test_creating_tractogram_item(self): rng = np.random.RandomState(42) streamline = rng.rand(rng.randint(10, 50), 3) - scalars = rng.rand(len(streamline), 5) - properties = rng.rand(42) + colors = rng.rand(len(streamline), 3) + mean_curvature = 1.11 + mean_color = np.array([0, 1, 0], dtype="f4") - # Create a streamline with only points - t = TractogramItem(streamline) - assert_equal(len(t), len(streamline)) - assert_array_equal(t.scalars, []) - assert_array_equal(t.properties, []) + data_for_streamline = {"mean_curvature": mean_curvature, + "mean_color": mean_color} + + data_for_points = {"colors": colors} - # Create a streamline with points, scalars and properties. - t = TractogramItem(streamline, scalars, properties) + # Create a tractogram item with a streamline, data. + t = TractogramItem(streamline, data_for_streamline, data_for_points) assert_equal(len(t), len(streamline)) assert_array_equal(t.streamline, streamline) assert_array_equal(list(t), streamline) - assert_equal(len(t), len(scalars)) - assert_array_equal(t.scalars, scalars) - assert_array_equal(t.properties, properties) - - # # Create a streamline with different number of scalars. 
- # scalars = rng.rand(len(streamline)+3, 5) - # assert_raises(ValueError, TractogramItem, streamline, scalars) + assert_array_equal(t.data_for_streamline['mean_curvature'], + mean_curvature) + assert_array_equal(t.data_for_streamline['mean_color'], + mean_color) + assert_array_equal(t.data_for_points['colors'], + colors) class TestTractogram(unittest.TestCase): def setUp(self): - self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") - self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") - self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") - self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), np.arange(2*3, dtype="f4").reshape((2, 3)), np.arange(5*3, dtype="f4").reshape((5, 3))] @@ -247,77 +242,56 @@ def setUp(self): np.array([(0, 1, 0)]*2, dtype="f4"), np.array([(0, 0, 1)]*5, dtype="f4")] - self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), - np.array([2.11, 2.22], dtype="f4"), - np.array([3.11, 3.22], dtype="f4")] - self.mean_curvature = np.array([1.11, 2.11, 3.11], dtype="f4") - self.mean_color = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype="f4") - + self.mean_color = np.array([[0, 1, 0], + [0, 0, 1], + [1, 0, 0]], dtype="f4") self.nb_tractogram = len(self.streamlines) - self.nb_scalars_per_point = self.colors[0].shape[1] - self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) - - def test_tractogram_creation_with_dix(self): - - tractogram = Tractogram( - streamlines=self.streamlines, - data_per_streamline={'mean_curvature': self.mean_curvature, - 'mean_color': self.mean_color}, - data_per_point={'colors': self.colors}) - - tractogram[:2] - - def test_tractogram_creation_from_arrays(self): - # Empty + def test_tractogram_creation(self): + # Create an empty tractogram. tractogram = Tractogram() assert_equal(len(tractogram), 0) assert_arrays_equal(tractogram.streamlines, []) - assert_arrays_equal(tractogram.scalars, []) - assert_arrays_equal(tractogram.properties, []) + assert_equal(tractogram.data_per_streamline, {}) + assert_equal(tractogram.data_per_point, {}) + assert_true(isiterable(tractogram)) - # Check if we can iterate through the tractogram. - for streamline in tractogram: - pass - - # Only streamlines + # Create a tractogram with only streamlines tractogram = Tractogram(streamlines=self.streamlines) assert_equal(len(tractogram), len(self.streamlines)) assert_arrays_equal(tractogram.streamlines, self.streamlines) - assert_arrays_equal(tractogram.scalars, []) - assert_arrays_equal(tractogram.properties, []) + assert_equal(tractogram.data_per_streamline, {}) + assert_equal(tractogram.data_per_point, {}) + assert_true(isiterable(tractogram)) - # Check if we can iterate through the tractogram. - for streamline in tractogram: - pass + # Create a tractogram with streamlines and other data. 
+ tractogram = Tractogram( + self.streamlines, + data_per_streamline={'mean_curvature': self.mean_curvature, + 'mean_color': self.mean_color}, + data_per_point={'colors': self.colors}) - # Points, scalars and properties - tractogram = Tractogram(self.streamlines, self.colors, self.mean_curvature_torsion) assert_equal(len(tractogram), len(self.streamlines)) assert_arrays_equal(tractogram.streamlines, self.streamlines) - assert_arrays_equal(tractogram.scalars, self.colors) - assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) + assert_arrays_equal(tractogram.data_per_streamline['mean_curvature'], + self.mean_curvature) + assert_arrays_equal(tractogram.data_per_streamline['mean_color'], + self.mean_color) + assert_arrays_equal(tractogram.data_per_point['colors'], + self.colors) - # Check if we can iterate through the tractogram. - for streamline in tractogram: - pass + assert_true(isiterable(tractogram)) # Inconsistent number of scalars between streamlines - scalars = [[(1, 0, 0)]*1, - [(0, 1, 0), (0, 1)], - [(0, 0, 1)]*5] - - assert_raises(ValueError, Tractogram, self.streamlines, scalars) - - # Unit test moved to test_base_format.py - # Inconsistent number of scalars between tractogram - scalars = [[(1, 0, 0)]*1, - [(0, 1)]*2, - [(0, 0, 1)]*5] + wrong_data = [[(1, 0, 0)]*1, + [(0, 1, 0), (0, 1)], + [(0, 0, 1)]*5] - assert_raises(ValueError, Tractogram, self.streamlines, scalars) + data_per_point = {'wrong_data': wrong_data} + assert_raises(ValueError, Tractogram, self.streamlines, + data_per_point=data_per_point) def test_tractogram_getter(self): # Tractogram with only streamlines @@ -326,78 +300,71 @@ def test_tractogram_getter(self): selected_tractogram = tractogram[::2] assert_equal(len(selected_tractogram), (len(self.streamlines)+1)//2) - assert_arrays_equal(selected_tractogram.streamlines, self.streamlines[::2]) - assert_equal(sum(map(len, selected_tractogram.scalars)), 0) - assert_equal(sum(map(len, selected_tractogram.properties)), 0) + assert_arrays_equal(selected_tractogram.streamlines, + self.streamlines[::2]) + assert_equal(tractogram.data_per_streamline, {}) + assert_equal(tractogram.data_per_point, {}) - # Tractogram with streamlines, scalars and properties - tractogram = Tractogram(self.streamlines, self.colors, self.mean_curvature_torsion) + # Create a tractogram with streamlines and other data. 
+ tractogram = Tractogram( + self.streamlines, + data_per_streamline={'mean_curvature': self.mean_curvature, + 'mean_color': self.mean_color}, + data_per_point={'colors': self.colors}) # Retrieve tractogram by their index for i, t in enumerate(tractogram): assert_array_equal(t.streamline, tractogram[i].streamline) - assert_array_equal(t.scalars, tractogram[i].scalars) - assert_array_equal(t.properties, tractogram[i].properties) + assert_array_equal(t.data_for_points['colors'], + tractogram[i].data_for_points['colors']) + + assert_array_equal(t.data_for_streamline['mean_curvature'], + tractogram[i].data_for_streamline['mean_curvature']) + + assert_array_equal(t.data_for_streamline['mean_color'], + tractogram[i].data_for_streamline['mean_color']) # Use slicing r_tractogram = tractogram[::-1] assert_arrays_equal(r_tractogram.streamlines, self.streamlines[::-1]) - assert_arrays_equal(r_tractogram.scalars, self.colors[::-1]) - assert_arrays_equal(r_tractogram.properties, self.mean_curvature_torsion[::-1]) - def test_tractogram_creation_from_generator(self): - # Create `Tractogram` from a generator yielding 3-tuples - gen = (x for x in zip(self.streamlines, self.colors, self.mean_curvature_torsion)) - - tractogram = Tractogram.create_from_generator(gen) - with suppress_warnings(): - assert_equal(len(tractogram), self.nb_tractogram) + assert_arrays_equal(r_tractogram.data_per_streamline['mean_curvature'], + self.mean_curvature[::-1]) + assert_arrays_equal(r_tractogram.data_per_streamline['mean_color'], + self.mean_color[::-1]) + assert_arrays_equal(r_tractogram.data_per_point['colors'], + self.colors[::-1]) - assert_arrays_equal(tractogram.streamlines, self.streamlines) - assert_arrays_equal(tractogram.scalars, self.colors) - assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) - - # Check if we can iterate through the tractogram. - for streamline in tractogram: - pass - - def test_tractogram_creation_from_coroutines(self): - # Points, scalars and properties - streamlines = lambda: (x for x in self.streamlines) - scalars = lambda: (x for x in self.colors) - properties = lambda: (x for x in self.mean_curvature_torsion) - - # To create tractogram from multiple coroutines use `LazyTractogram`. 
- assert_raises(TypeError, Tractogram, streamlines, scalars, properties) + def test_tractogram_add_new_data(self): + # Tractogram with only streamlines + tractogram = Tractogram(streamlines=self.streamlines) - def test_header(self): - # Empty Tractogram, with default header - tractogram = Tractogram() - assert_equal(tractogram.header.nb_streamlines, 0) - assert_equal(tractogram.header.nb_scalars_per_point, 0) - assert_equal(tractogram.header.nb_properties_per_streamline, 0) - assert_array_equal(tractogram.header.voxel_sizes, (1, 1, 1)) - assert_array_equal(tractogram.header.to_world_space, np.eye(4)) - assert_equal(tractogram.header.extra, {}) + tractogram.data_per_streamline['mean_curvature'] = self.mean_curvature + tractogram.data_per_streamline['mean_color'] = self.mean_color + tractogram.data_per_point['colors'] = self.colors - tractogram = Tractogram(self.streamlines, self.colors, self.mean_curvature_torsion) + # Retrieve tractogram by their index + for i, t in enumerate(tractogram): + assert_array_equal(t.streamline, tractogram[i].streamline) + assert_array_equal(t.data_for_points['colors'], + tractogram[i].data_for_points['colors']) - assert_equal(tractogram.header.nb_streamlines, len(self.streamlines)) - assert_equal(tractogram.header.nb_scalars_per_point, self.colors[0].shape[1]) - assert_equal(tractogram.header.nb_properties_per_streamline, self.mean_curvature_torsion[0].shape[0]) + assert_array_equal(t.data_for_streamline['mean_curvature'], + tractogram[i].data_for_streamline['mean_curvature']) - # Modifying voxel_sizes should be reflected in to_world_space - tractogram.header.voxel_sizes = (2, 3, 4) - assert_array_equal(tractogram.header.voxel_sizes, (2, 3, 4)) - assert_array_equal(np.diag(tractogram.header.to_world_space), (2, 3, 4, 1)) + assert_array_equal(t.data_for_streamline['mean_color'], + tractogram[i].data_for_streamline['mean_color']) - # Modifying scaling of to_world_space should be reflected in voxel_sizes - tractogram.header.to_world_space = np.diag([4, 3, 2, 1]) - assert_array_equal(tractogram.header.voxel_sizes, (4, 3, 2)) - assert_array_equal(tractogram.header.to_world_space, np.diag([4, 3, 2, 1])) + # Use slicing + r_tractogram = tractogram[::-1] + assert_arrays_equal(r_tractogram.streamlines, self.streamlines[::-1]) - # Test that we can run __repr__ without error. - repr(tractogram.header) + assert_arrays_equal(r_tractogram.data_per_streamline['mean_curvature'], + self.mean_curvature[::-1]) + assert_arrays_equal(r_tractogram.data_per_streamline['mean_color'], + self.mean_color[::-1]) + assert_arrays_equal(r_tractogram.data_per_point['colors'], + self.colors[::-1]) class TestLazyTractogram(unittest.TestCase): @@ -553,27 +520,3 @@ def test_lazy_tractogram_len(self): # This should *not* produce a warning. 
assert_equal(len(tractogram), 1234) assert_equal(len(w), 0) - - def test_lazy_tractogram_header(self): - # Empty `LazyTractogram`, with default header - tractogram = LazyTractogram() - assert_true(tractogram.header.nb_streamlines is None) - assert_equal(tractogram.header.nb_scalars_per_point, 0) - assert_equal(tractogram.header.nb_properties_per_streamline, 0) - assert_array_equal(tractogram.header.voxel_sizes, (1, 1, 1)) - assert_array_equal(tractogram.header.to_world_space, np.eye(4)) - assert_equal(tractogram.header.extra, {}) - - streamlines = lambda: (x for x in self.streamlines) - scalars = lambda: (x for x in self.colors) - properties = lambda: (x for x in self.mean_curvature_torsion) - tractogram = LazyTractogram(streamlines) - header = tractogram.header - - assert_equal(header.nb_scalars_per_point, 0) - tractogram.scalars = scalars - assert_equal(header.nb_scalars_per_point, self.nb_scalars_per_point) - - assert_equal(header.nb_properties_per_streamline, 0) - tractogram.properties = properties - assert_equal(header.nb_properties_per_streamline, self.nb_properties_per_streamline) From bea15bebe9316029ef4d2044bcb7f6b84ff341e4 Mon Sep 17 00:00:00 2001 From: Eleftherios Garyfallidis Date: Mon, 2 Nov 2015 18:14:31 -0500 Subject: [PATCH 018/135] Updated TractogramFile --- nibabel/streamlines/base_format.py | 93 ++++++++++-------------------- 1 file changed, 30 insertions(+), 63 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 52563a5a97..e1138420e3 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -499,55 +499,54 @@ def transform(self, affine): return super(LazyTractogram, self).transform(affine, lazy=True) +class abstractclassmethod(classmethod): + __isabstractmethod__ = True + + def __init__(self, callable): + callable.__isabstractmethod__ = True + super(abstractclassmethod, self).__init__(callable) + + class TractogramFile(object): - ''' Convenience class to encapsulate streamlines file format. ''' + ''' Convenience class to encapsulate tractogram file format. ''' __metaclass__ = ABCMeta - def __init__(self, tractogram): - self.tractogram = tractogram + def __init__(self, tractogram, header=None): + self._tractogram = tractogram + self._header = TractogramHeader() if header is None else header + + @property + def tractogram(self): + return self._tractogram @property def streamlines(self): return self.tractogram.streamlines @property - def scalars(self): - return self.tractogram.scalars + def header(self): + return self._header - @property - def properties(self): - return self.tractogram.properties + def get_tractogram(self): + return self.tractogram - @property - def header(self): - return self.tractogram.header + def get_header(self): + return self.header @classmethod def get_magic_number(cls): ''' Returns streamlines file's magic number. ''' raise NotImplementedError() - @classmethod - def can_save_scalars(cls): - ''' Tells if the streamlines format supports saving scalars. ''' - raise NotImplementedError() - - @classmethod - def can_save_properties(cls): - ''' Tells if the streamlines format supports saving properties. ''' - raise NotImplementedError() - @classmethod def is_correct_format(cls, fileobj): ''' Checks if the file has the right streamlines file format. - Parameters ---------- fileobj : string or file-like object If string, a filename; otherwise an open file-like object pointing to a streamlines file (and ready to read from the beginning of the header). 
- Returns ------- is_correct_format : boolean @@ -555,65 +554,33 @@ def is_correct_format(cls, fileobj): ''' raise NotImplementedError() - @staticmethod - def load(fileobj, ref, lazy_load=True): + @abstractclassmethod + def load(cls, fileobj, lazy_load=True): ''' Loads streamlines from a file-like object. - Parameters ---------- fileobj : string or file-like object If string, a filename; otherwise an open file-like object pointing to a streamlines file (and ready to read from the beginning of the header). - - ref : filename | `Nifti1Image` object | 2D array (4,4) - Reference space where streamlines live in `fileobj`. - - lazy_load : boolean + lazy_load : boolean (optional) Load streamlines in a lazy manner i.e. they will not be kept in memory. For postprocessing speed, turn off this option. - Returns ------- - streamlines : Tractogram object - Returns an object containing streamlines' data and header - information. See 'nibabel.Tractogram'. + tractogram_file : ``TractogramFile`` object + Returns an object containing tractogram data and header + information. ''' raise NotImplementedError() - @staticmethod - def save(streamlines, fileobj, ref=None): + @abstractmethod + def save(self, fileobj): ''' Saves streamlines to a file-like object. - Parameters ---------- - streamlines : Tractogram object - Object containing streamlines' data and header information. - See 'nibabel.Tractogram'. - fileobj : string or file-like object If string, a filename; otherwise an open file-like object opened and ready to write. - - ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) - Reference space where streamlines will live in `fileobj`. - ''' - raise NotImplementedError() - - @staticmethod - def pretty_print(streamlines): - ''' Gets a formatted string of the header of a streamlines file format. - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to a streamlines file (and ready to read from the - beginning of the header). - - Returns - ------- - info : string - Header information relevant to the streamlines file format. ''' raise NotImplementedError() From 7399237ef39da6a27b63c2f7a794ca22b91906aa Mon Sep 17 00:00:00 2001 From: Eleftherios Garyfallidis Date: Mon, 2 Nov 2015 18:17:02 -0500 Subject: [PATCH 019/135] minor cleanup --- nibabel/streamlines/trk.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 1042e5f9af..61862f7a0e 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -345,16 +345,6 @@ def get_magic_number(cls): ''' Return TRK's magic number. ''' return cls.MAGIC_NUMBER - @classmethod - def can_save_scalars(cls): - ''' Tells if the streamlines format supports saving scalars. ''' - return True - - @classmethod - def can_save_properties(cls): - ''' Tells if the streamlines format supports saving properties. ''' - return True - @classmethod def is_correct_format(cls, fileobj): ''' Check if the file is in TRK format. 
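A word on the ``abstractclassmethod`` recipe that base_format.py now carries: before Python 3.3, stacking ``@classmethod`` on ``@abstractmethod`` did not propagate the abstract flag, so an abstract ``load`` factory declared with the plain decorators would not be enforced by ``ABCMeta``. The following self-contained Python 3 sketch shows the recipe doing its job; the ``*Format`` class names are invented for illustration and are not part of nibabel.

from abc import ABCMeta


class abstractclassmethod(classmethod):
    # Same recipe as in base_format.py: flag the wrapped callable as
    # abstract so ABCMeta refuses to instantiate subclasses that do
    # not override it.
    __isabstractmethod__ = True

    def __init__(self, callable):
        callable.__isabstractmethod__ = True
        super(abstractclassmethod, self).__init__(callable)


class AbstractFormat(metaclass=ABCMeta):
    @abstractclassmethod
    def load(cls, fileobj, lazy_load=True):
        raise NotImplementedError()


class TrkLikeFormat(AbstractFormat):
    @classmethod
    def load(cls, fileobj, lazy_load=True):
        return cls()


print(TrkLikeFormat.load(None))    # fine: 'load' was overridden

try:
    class Incomplete(AbstractFormat):
        pass                       # does not override 'load'
    Incomplete()                   # ABCMeta blocks instantiation
except TypeError as err:
    print(err)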
From 0994547e7ea9b7d588d7df896210079ec2757a17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 2 Nov 2015 23:06:44 -0500 Subject: [PATCH 020/135] Changed unit tests to reflect modifications made to Tractogram --- nibabel/streamlines/__init__.py | 2 +- nibabel/streamlines/base_format.py | 39 ++- nibabel/streamlines/header.py | 4 +- nibabel/streamlines/tests/test_base_format.py | 9 + nibabel/streamlines/tests/test_streamlines.py | 127 +++++---- nibabel/streamlines/tests/test_trk.py | 142 ++++++---- nibabel/streamlines/trk.py | 264 ++++++++++-------- 7 files changed, 344 insertions(+), 243 deletions(-) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index a58f72901c..cfd8990741 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -90,7 +90,7 @@ def load(fileobj, lazy_load=False, ref=None): if tractogram_file is None: raise TypeError("Unknown format for 'fileobj': {}".format(fileobj)) - return tractogram_file.load(fileobj, lazy_load=lazy_load, ref=ref) + return tractogram_file.load(fileobj, lazy_load=lazy_load) def save(tractogram_file, filename): diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index e1138420e3..9276f01f2c 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -38,13 +38,24 @@ def __init__(self, iterable=None): iterable : iterable (optional) If specified, create a ``CompactList`` object initialized from iterable's items. Otherwise, create an empty ``CompactList``. + + Notes + ----- + If `iterable` is a ``CompactList`` object, a view is returned and no + memory is allocated. For an actual copy use the `.copy()` method. """ # Create new empty `CompactList` object. self._data = None self._offsets = [] self._lengths = [] - if iterable is not None: + if isinstance(iterable, CompactList): + # Create a view. + self._data = iterable._data + self._offsets = iterable._offsets + self._lengths = iterable._lengths + + elif iterable is not None: # Initialize the `CompactList` object from iterable's item. BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. @@ -249,14 +260,8 @@ def __init__(self, streamlines=None, data_per_point=None): self.streamlines = streamlines - - self.data_per_streamline = {} - if data_per_streamline is not None: - self.data_per_streamline = data_per_streamline - - self.data_per_point = {} - if data_per_point is not None: - self.data_per_point = data_per_point + self.data_per_streamline = data_per_streamline + self.data_per_point = data_per_point @property def streamlines(self): @@ -272,6 +277,9 @@ def data_per_streamline(self): @data_per_streamline.setter def data_per_streamline(self, value): + if value is None: + value = {} + self._data_per_streamline = {} for k, v in value.items(): self._data_per_streamline[k] = np.asarray(v) @@ -282,6 +290,9 @@ def data_per_point(self): @data_per_point.setter def data_per_point(self, value): + if value is None: + value = {} + self._data_per_point = {} for k, v in value.items(): self._data_per_point[k] = CompactList(v) @@ -538,6 +549,16 @@ def get_magic_number(cls): ''' Returns streamlines file's magic number. ''' raise NotImplementedError() + @classmethod + def support_data_per_point(cls): + ''' Tells if this tractogram format supports saving data per point. ''' + raise NotImplementedError() + + @classmethod + def support_data_per_streamline(cls): + ''' Tells if this tractogram format supports saving data per streamline. 
''' + raise NotImplementedError() + @classmethod def is_correct_format(cls, fileobj): ''' Checks if the file has the right streamlines file format. diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py index 706afd348a..2298e395b2 100644 --- a/nibabel/streamlines/header.py +++ b/nibabel/streamlines/header.py @@ -19,7 +19,7 @@ class Field: DIMENSIONS = "dimensions" MAGIC_NUMBER = "magic_number" ORIGIN = "origin" - to_world_space = "to_world_space" + VOXEL_TO_RASMM = "voxel_to_rasmm" VOXEL_ORDER = "voxel_order" ENDIAN = "endian" @@ -38,7 +38,7 @@ def to_world_space(self): @to_world_space.setter def to_world_space(self, value): - self._to_world_space = np.array(value, dtype=np.float32) + self._to_world_space = np.asarray(value, dtype=np.float32) @property def voxel_sizes(self): diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_base_format.py index 5ae8c0d5d8..53d8589ab7 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_base_format.py @@ -293,6 +293,15 @@ def test_tractogram_creation(self): assert_raises(ValueError, Tractogram, self.streamlines, data_per_point=data_per_point) + # Inconsistent number of scalars between streamlines + wrong_data = [[(1, 0, 0)]*1, + [(0, 1)]*2, + [(0, 0, 1)]*5] + + data_per_point = {'wrong_data': wrong_data} + assert_raises(ValueError, Tractogram, self.streamlines, + data_per_point=data_per_point) + def test_tractogram_getter(self): # Tractogram with only streamlines tractogram = Tractogram(streamlines=self.streamlines) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index e78be90429..5dabad1815 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -20,19 +20,20 @@ DATA_PATH = pjoin(os.path.dirname(__file__), 'data') -def check_tractogram(tractogram, nb_streamlines, streamlines, scalars, properties): +def check_tractogram(tractogram, nb_streamlines, streamlines, data_per_streamline, data_per_point): # Check data assert_equal(len(tractogram), nb_streamlines) assert_arrays_equal(tractogram.streamlines, streamlines) - assert_arrays_equal(tractogram.scalars, scalars) - assert_arrays_equal(tractogram.properties, properties) - assert_true(isiterable(tractogram)) - assert_equal(tractogram.header.nb_streamlines, nb_streamlines) - nb_scalars_per_point = 0 if len(scalars) == 0 else len(scalars[0][0]) - nb_properties_per_streamline = 0 if len(properties) == 0 else len(properties[0]) - assert_equal(tractogram.header.nb_scalars_per_point, nb_scalars_per_point) - assert_equal(tractogram.header.nb_properties_per_streamline, nb_properties_per_streamline) + for key in data_per_streamline.keys(): + assert_arrays_equal(tractogram.data_per_streamline[key], + data_per_streamline[key]) + + for key in data_per_point.keys(): + assert_arrays_equal(tractogram.data_per_point[key], + data_per_point[key]) + + assert_true(isiterable(tractogram)) def test_is_supported(): @@ -127,6 +128,9 @@ def setUp(self): np.array([2.11, 2.22], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] + self.data_per_point = {'scalars': self.colors} + self.data_per_streamline = {'properties': self.mean_curvature_torsion} + self.nb_streamlines = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) @@ -135,91 +139,94 @@ def setUp(self): def test_load_empty_file(self): for empty_filename in 
self.empty_filenames: tractogram_file = nib.streamlines.load(empty_filename, - lazy_load=False, - ref=self.to_world_space) + lazy_load=False) assert_true(isinstance(tractogram_file, TractogramFile)) assert_true(type(tractogram_file.tractogram), Tractogram) - check_tractogram(tractogram_file.tractogram, 0, [], [], []) + check_tractogram(tractogram_file.tractogram, 0, [], {}, {}) def test_load_simple_file(self): for simple_filename in self.simple_filenames: tractogram_file = nib.streamlines.load(simple_filename, - lazy_load=False, - ref=self.to_world_space) + lazy_load=False) assert_true(isinstance(tractogram_file, TractogramFile)) assert_true(type(tractogram_file.tractogram), Tractogram) check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - self.streamlines, [], []) + self.streamlines, {}, {}) - # Test lazy_load - tractogram_file = nib.streamlines.load(simple_filename, - lazy_load=True, - ref=self.to_world_space) + # # Test lazy_load + # tractogram_file = nib.streamlines.load(simple_filename, + # lazy_load=True) - assert_true(isinstance(tractogram_file, TractogramFile)) - assert_true(type(tractogram_file.tractogram), LazyTractogram) - check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - self.streamlines, [], []) + # assert_true(isinstance(tractogram_file, TractogramFile)) + # assert_true(type(tractogram_file.tractogram), LazyTractogram) + # check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + # self.streamlines, {}, {}) def test_load_complex_file(self): for complex_filename in self.complex_filenames: file_format = nib.streamlines.detect_format(complex_filename) - scalars = [] - if file_format.can_save_scalars(): - scalars = self.colors + data_per_point = {} + if file_format.support_data_per_point(): + data_per_point = {'scalars': self.colors} - properties = [] - if file_format.can_save_properties(): - properties = self.mean_curvature_torsion + data_per_streamline = [] + if file_format.support_data_per_streamline(): + data_per_streamline = {'properties': self.mean_curvature_torsion} tractogram_file = nib.streamlines.load(complex_filename, - lazy_load=False, - ref=self.to_world_space) + lazy_load=False) assert_true(isinstance(tractogram_file, TractogramFile)) assert_true(type(tractogram_file.tractogram), Tractogram) check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - self.streamlines, scalars, properties) - - # Test lazy_load - tractogram_file = nib.streamlines.load(complex_filename, - lazy_load=True, - ref=self.to_world_space) - assert_true(isinstance(tractogram_file, TractogramFile)) - assert_true(type(tractogram_file.tractogram), LazyTractogram) - check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - self.streamlines, scalars, properties) + self.streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + + # # Test lazy_load + # tractogram_file = nib.streamlines.load(complex_filename, + # lazy_load=True) + # assert_true(isinstance(tractogram_file, TractogramFile)) + # assert_true(type(tractogram_file.tractogram), LazyTractogram) + # check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + # self.streamlines, + # data_per_streamline=data_per_streamline, + # data_per_point=data_per_point) def test_save_simple_file(self): tractogram = Tractogram(self.streamlines) - for ext in nib.streamlines.FORMATS.keys(): + for ext, cls in nib.streamlines.FORMATS.items(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: - nib.streamlines.save(tractogram, f.name) - 
loaded_tractogram = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) - check_tractogram(loaded_tractogram, self.nb_streamlines, - self.streamlines, [], []) + nib.streamlines.save_tractogram(tractogram, f.name) + tractogram_file = nib.streamlines.load(f, lazy_load=False) + check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + self.streamlines, {}, {}) def test_save_complex_file(self): - tractogram = Tractogram(self.streamlines, scalars=self.colors, properties=self.mean_curvature_torsion) + tractogram = Tractogram(self.streamlines, + data_per_streamline=self.data_per_streamline, + data_per_point=self.data_per_point) for ext, cls in nib.streamlines.FORMATS.items(): with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: with clear_and_catch_warnings(record=True, modules=[trk]) as w: - nib.streamlines.save(tractogram, f.name) + nib.streamlines.save_tractogram(tractogram, f.name) - # If streamlines format does not support saving scalars or - # properties, a warning message should be issued. - if not (cls.can_save_scalars() and cls.can_save_properties()): + # If streamlines format does not support saving data per point + # or data per streamline, a warning message should be issued. + if not (cls.support_data_per_point() and cls.support_data_per_streamline()): assert_equal(len(w), 1) assert_true(issubclass(w[0].category, UsageWarning)) - scalars = [] - if cls.can_save_scalars(): - scalars = self.colors + data_per_point = {} + if cls.support_data_per_point(): + data_per_point = self.data_per_point - properties = [] - if cls.can_save_properties(): - properties = self.mean_curvature_torsion + data_per_streamline = [] + if cls.support_data_per_streamline(): + data_per_streamline = self.data_per_streamline - loaded_tractogram = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) - check_tractogram(loaded_tractogram, self.nb_streamlines, - self.streamlines, scalars, properties) + tractogram_file = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) + check_tractogram(tractogram_file.tractogram, self.nb_streamlines, + self.streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index e1584731db..1b1c2ab7a6 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -5,7 +5,7 @@ from nibabel.externals.six import BytesIO from nibabel.testing import suppress_warnings, clear_and_catch_warnings -from nibabel.testing import assert_arrays_equal, assert_tractogram_equal, isiterable +from nibabel.testing import assert_arrays_equal, isiterable from nose.tools import assert_equal, assert_raises, assert_true from .. 
import base_format @@ -18,19 +18,36 @@ DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') -def check_tractogram(tractogram, nb_streamlines, streamlines, scalars, properties): +def assert_tractogram_equal(t1, t2): + assert_equal(len(t1), len(t2)) + assert_arrays_equal(t1.streamlines, t2.streamlines) + + assert_equal(len(t1.data_per_streamline), len(t2.data_per_streamline)) + for key in t1.data_per_streamline.keys(): + assert_arrays_equal(t1.data_per_streamline[key], + t2.data_per_streamline[key]) + + assert_equal(len(t1.data_per_point), len(t2.data_per_point)) + for key in t1.data_per_point.keys(): + assert_arrays_equal(t1.data_per_point[key], + t2.data_per_point[key]) + + + +def check_tractogram(tractogram, nb_streamlines, streamlines, data_per_streamline, data_per_point): # Check data assert_equal(len(tractogram), nb_streamlines) assert_arrays_equal(tractogram.streamlines, streamlines) - assert_arrays_equal(tractogram.scalars, scalars) - assert_arrays_equal(tractogram.properties, properties) - assert_true(isiterable(tractogram)) - assert_equal(tractogram.header.nb_streamlines, nb_streamlines) - nb_scalars_per_point = 0 if len(scalars) == 0 else len(scalars[0][0]) - nb_properties_per_streamline = 0 if len(properties) == 0 else len(properties[0]) - assert_equal(tractogram.header.nb_scalars_per_point, nb_scalars_per_point) - assert_equal(tractogram.header.nb_properties_per_streamline, nb_properties_per_streamline) + for key in data_per_streamline.keys(): + assert_arrays_equal(tractogram.data_per_streamline[key], + data_per_streamline[key]) + + for key in data_per_point.keys(): + assert_arrays_equal(tractogram.data_per_point[key], + data_per_point[key]) + + assert_true(isiterable(tractogram)) class TestTRK(unittest.TestCase): @@ -54,35 +71,40 @@ def setUp(self): np.array([2.11, 2.22], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] + self.data_per_point = {'scalars': self.colors} + self.data_per_streamline = {'properties': self.mean_curvature_torsion} + self.nb_streamlines = len(self.streamlines) self.nb_scalars_per_point = self.colors[0].shape[1] self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) self.affine = np.eye(4) def test_load_empty_file(self): - trk = TrkFile.load(self.empty_trk_filename, ref=None, lazy_load=False) - check_tractogram(trk.tractogram, 0, [], [], []) + trk = TrkFile.load(self.empty_trk_filename, lazy_load=False) + check_tractogram(trk.tractogram, 0, [], {}, {}) - trk = TrkFile.load(self.empty_trk_filename, ref=None, lazy_load=True) - # Suppress warning about loading a TRK file in lazy mode with count=0. - with suppress_warnings(): - check_tractogram(trk.tractogram, 0, [], [], []) + # trk = TrkFile.load(self.empty_trk_filename, lazy_load=True) + # # Suppress warning about loading a TRK file in lazy mode with count=0. 
+ # with suppress_warnings(): + # check_tractogram(trk.tractogram, 0, [], [], []) def test_load_simple_file(self): - trk = TrkFile.load(self.simple_trk_filename, ref=None, lazy_load=False) - check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) + trk = TrkFile.load(self.simple_trk_filename, lazy_load=False) + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, {}, {}) - trk = TrkFile.load(self.simple_trk_filename, ref=None, lazy_load=True) - check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) + # trk = TrkFile.load(self.simple_trk_filename, lazy_load=True) + # check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) def test_load_complex_file(self): - trk = TrkFile.load(self.complex_trk_filename, ref=None, lazy_load=False) + trk = TrkFile.load(self.complex_trk_filename, lazy_load=False) check_tractogram(trk.tractogram, self.nb_streamlines, - self.streamlines, self.colors, self.mean_curvature_torsion) + self.streamlines, + data_per_point=self.data_per_point, + data_per_streamline=self.data_per_streamline) - trk = TrkFile.load(self.complex_trk_filename, ref=None, lazy_load=True) - check_tractogram(trk.tractogram, self.nb_streamlines, - self.streamlines, self.colors, self.mean_curvature_torsion) + # trk = TrkFile.load(self.complex_trk_filename, lazy_load=True) + # check_tractogram(trk.tractogram, self.nb_streamlines, + # self.streamlines, self.colors, self.mean_curvature_torsion) def test_load_file_with_wrong_information(self): trk_file = open(self.simple_trk_filename, 'rb').read() @@ -91,19 +113,19 @@ def test_load_file_with_wrong_information(self): count = np.array(0, dtype="int32").tostring() new_trk_file = trk_file[:1000-12] + count + trk_file[1000-8:] trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) - check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) + check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, {}, {}) - trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) - with clear_and_catch_warnings(record=True, modules=[base_format]) as w: - check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) - assert_equal(len(w), 1) - assert_true(issubclass(w[0].category, UsageWarning)) + # trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) + # with clear_and_catch_warnings(record=True, modules=[base_format]) as w: + # check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, {}, {}) + # assert_equal(len(w), 1) + # assert_true(issubclass(w[0].category, UsageWarning)) # Simulate a TRK file where `voxel_order` was not provided. 
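        # (Note: `voxel_order` occupies bytes 948-950 of the 1000-byte TRK
        # header, which is why bytes 948:951 are overwritten below; the
        # loader is expected to fall back to "LPS" and issue a
        # HeaderWarning.)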
voxel_order = np.zeros(1, dtype="|S3").tostring() new_trk_file = trk_file[:948] + voxel_order + trk_file[948+3:] with clear_and_catch_warnings(record=True, modules=[trk]) as w: - TrkFile.load(BytesIO(new_trk_file), ref=None) + TrkFile.load(BytesIO(new_trk_file)) assert_equal(len(w), 1) assert_true(issubclass(w[0].category, HeaderWarning)) assert_true("LPS" in str(w[0].message)) @@ -128,7 +150,7 @@ def test_write_simple_file(self): loaded_trk = TrkFile.load(trk_file) check_tractogram(loaded_trk.tractogram, self.nb_streamlines, - self.streamlines, [], []) + self.streamlines, {}, {}) loaded_trk_orig = TrkFile.load(self.simple_trk_filename) assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) @@ -138,7 +160,8 @@ def test_write_simple_file(self): def test_write_complex_file(self): # With scalars - tractogram = Tractogram(self.streamlines, scalars=self.colors) + tractogram = Tractogram(self.streamlines, + data_per_point=self.data_per_point) trk_file = BytesIO() trk = TrkFile(tractogram, ref=self.affine) @@ -146,23 +169,28 @@ def test_write_complex_file(self): trk_file.seek(0, os.SEEK_SET) loaded_trk = TrkFile.load(trk_file, lazy_load=False) - check_tractogram(loaded_trk.tractogram, self.nb_streamlines, - self.streamlines, self.colors, []) + check_tractogram(loaded_trk.tractogram, self.nb_streamlines, self.streamlines, + data_per_streamline={}, + data_per_point=self.data_per_point) # With properties - tractogram = Tractogram(self.streamlines, properties=self.mean_curvature_torsion) + tractogram = Tractogram(self.streamlines, + data_per_streamline=self.data_per_streamline) - trk_file = BytesIO() trk = TrkFile(tractogram, ref=self.affine) + trk_file = BytesIO() trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) loaded_trk = TrkFile.load(trk_file, lazy_load=False) - check_tractogram(loaded_trk.tractogram, self.nb_streamlines, - self.streamlines, [], self.mean_curvature_torsion) + check_tractogram(loaded_trk.tractogram, self.nb_streamlines, self.streamlines, + data_per_streamline=self.data_per_streamline, + data_per_point={}) # With scalars and properties - tractogram = Tractogram(self.streamlines, scalars=self.colors, properties=self.mean_curvature_torsion) + tractogram = Tractogram(self.streamlines, + data_per_point=self.data_per_point, + data_per_streamline=self.data_per_streamline) trk_file = BytesIO() trk = TrkFile(tractogram, ref=self.affine) @@ -170,8 +198,9 @@ def test_write_complex_file(self): trk_file.seek(0, os.SEEK_SET) loaded_trk = TrkFile.load(trk_file, lazy_load=False) - check_tractogram(loaded_trk.tractogram, self.nb_streamlines, - self.streamlines, self.colors, self.mean_curvature_torsion) + check_tractogram(loaded_trk.tractogram, self.nb_streamlines, self.streamlines, + data_per_streamline=self.data_per_streamline, + data_per_point=self.data_per_point) loaded_trk_orig = TrkFile.load(self.complex_trk_filename) assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) @@ -185,7 +214,8 @@ def test_write_erroneous_file(self): [(0, 1, 0)], [(0, 0, 1)]] - tractogram = Tractogram(self.streamlines, scalars) + tractogram = Tractogram(self.streamlines, + data_per_point={'scalars': scalars}) trk = TrkFile(tractogram, ref=self.affine) assert_raises(DataError, trk.save, BytesIO()) @@ -193,9 +223,10 @@ def test_write_erroneous_file(self): scalars = [[(1, 0, 0)]*1, [(0, 1, 0)]*2] - tractogram = Tractogram(self.streamlines, scalars) + tractogram = Tractogram(self.streamlines, + data_per_point={'scalars': scalars}) trk = TrkFile(tractogram, ref=self.affine) - 
assert_raises(DataError, trk.save, BytesIO()) + assert_raises(IndexError, trk.save, BytesIO()) # # Unit test moved to test_base_format.py # # Inconsistent number of scalars between points @@ -206,31 +237,22 @@ def test_write_erroneous_file(self): # tractogram = Tractogram(self.streamlines, scalars) # assert_raises(ValueError, TrkFile.save, tractogram, BytesIO()) - # # Unit test moved to test_base_format.py - # # Inconsistent number of scalars between streamlines - # scalars = [[(1, 0, 0)]*1, - # [(0, 1)]*2, - # [(0, 0, 1)]*5] - - # tractogram = Tractogram(self.streamlines, scalars) - # assert_raises(DataError, TrkFile.save, tractogram, BytesIO()) - - # Unit test moved to test_base_format.py # Inconsistent number of properties properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] - tractogram = Tractogram(self.streamlines, properties=properties) + tractogram = Tractogram(self.streamlines, + data_per_streamline={'properties': properties}) trk = TrkFile(tractogram, ref=self.affine) assert_raises(DataError, trk.save, BytesIO()) - # Unit test moved to test_base_format.py # No properties for every streamlines properties = [np.array([1.11, 1.22], dtype="f4"), np.array([2.11, 2.22], dtype="f4")] - tractogram = Tractogram(self.streamlines, properties=properties) + tractogram = Tractogram(self.streamlines, + data_per_streamline={'properties': properties}) trk = TrkFile(tractogram, ref=self.affine) - assert_raises(DataError, trk.save, BytesIO()) + assert_raises(IndexError, trk.save, BytesIO()) # def test_write_file_lazy_tractogram(self): # streamlines = lambda: (point for point in self.streamlines) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 61862f7a0e..b93a65f5d9 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -59,7 +59,7 @@ ('scalar_name', 'S20', 10), (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), ('property_name', 'S20', 10), - (Field.to_world_space, 'f4', (4, 4)), # new field for version 2 + (Field.VOXEL_TO_RASMM, 'f4', (4, 4)), # new field for version 2 ('reserved', 'S444'), (Field.VOXEL_ORDER, 'S4'), ('pad2', 'S4'), @@ -140,16 +140,6 @@ def __init__(self, fileobj): # Keep the file position where the data begin. self.offset_data = f.tell() - # if f.name is not None and self.header[Field.NB_STREAMLINES] > 0: - # filesize = os.path.getsize(f.name) - self.offset_data - # # Remove properties - # filesize -= self.header[Field.NB_STREAMLINES] * self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * 4. - # # Remove the points count at the beginning of each streamline. - # filesize -= self.header[Field.NB_STREAMLINES] * 4. - # # Get nb points. - # nb_points = filesize / ((3 + self.header[Field.NB_SCALARS_PER_POINT]) * 4.) 
- # self.header[Field.NB_POINTS] = int(nb_points) - def __iter__(self): i4_dtype = np.dtype(self.endianness + "i4") f4_dtype = np.dtype(self.endianness + "f4") @@ -214,11 +204,11 @@ def create_empty_header(cls): header[Field.MAGIC_NUMBER] = TrkFile.MAGIC_NUMBER header[Field.VOXEL_SIZES] = (1, 1, 1) header[Field.DIMENSIONS] = (1, 1, 1) - header[Field.to_world_space] = np.eye(4) + header[Field.VOXEL_TO_RASMM] = np.eye(4) header['version'] = 2 header['hdr_size'] = TrkFile.HEADER_SIZE - return header + return header[0] def __init__(self, fileobj, header): self.header = self.create_empty_header() @@ -235,7 +225,7 @@ def __init__(self, fileobj, header): self.header[Field.NB_SCALARS_PER_POINT] = header.nb_scalars_per_point self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = header.nb_properties_per_streamline self.header[Field.VOXEL_SIZES] = header.voxel_sizes - self.header[Field.to_world_space] = header.to_world_space + self.header[Field.VOXEL_TO_RASMM] = header.to_world_space self.header[Field.VOXEL_ORDER] = header.voxel_order # Keep counts for correcting incoherent fields or warn. @@ -248,27 +238,34 @@ def __init__(self, fileobj, header): self.file = Opener(fileobj, mode="wb") # Keep track of the beginning of the header. self.beginning = self.file.tell() - self.file.write(self.header[0].tostring()) + self.file.write(self.header.tostring()) def write(self, tractogram): i4_dtype = np.dtype("i4") f4_dtype = np.dtype("f4") + # TRK's streamlines need to be in 'voxelmm' space and by definition + # tractogram streamlines are in RAS+ and mm space. + affine = np.linalg.inv(self.header[Field.VOXEL_TO_RASMM]) + affine[range(3), range(3)] *= self.header[Field.VOXEL_SIZES] + + # TrackVis considers coordinate (0,0,0) to be the corner of the + # voxel whereas streamlines passed in parameters assume (0,0,0) + # to be the center of the voxel. Thus, streamlines are shifted of + # half a voxel. + affine[:-1, -1] += np.array(self.header[Field.VOXEL_SIZES])/2. + + tractogram.apply_affine(affine) + for t in tractogram: - if len(t.scalars) > 0 and len(t.scalars) != len(t.streamline): + if any((len(d) != len(t.streamline) for d in t.data_for_points.values())): raise DataError("Missing scalars for some points!") points = np.asarray(t.streamline, dtype=f4_dtype) - scalars = np.asarray(t.scalars, dtype=f4_dtype).reshape((len(points), -1)) - properties = np.asarray(t.properties, dtype=f4_dtype) - - # TRK's streamlines need to be in 'voxelmm' space - points = points * self.header[Field.VOXEL_SIZES] - # TrackVis considers coordinate (0,0,0) to be the corner of the - # voxel whereas streamlines passed in parameters assume (0,0,0) - # to be the center of the voxel. Thus, streamlines are shifted of - # half a voxel. - points += np.array(self.header[Field.VOXEL_SIZES])/2. + keys = sorted(t.data_for_points.keys()) + scalars = np.asarray([t.data_for_points[k] for k in keys], dtype=f4_dtype).reshape((len(points), -1)) + keys = sorted(t.data_for_streamline.keys()) + properties = np.asarray([t.data_for_streamline[k] for k in keys], dtype=f4_dtype).flatten() data = struct.pack(i4_dtype.str[:-1], len(points)) data += np.concatenate((points, scalars), axis=1).tostring() @@ -298,7 +295,91 @@ def write(self, tractogram): # Overwrite header with updated one. 
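        # (Counts such as the number of streamlines are only known once the
        # last streamline has been written, so the writer seeks back to the
        # position saved in `self.beginning` and rewrites the whole header
        # in place.)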
self.file.seek(self.beginning, os.SEEK_SET) - self.file.write(self.header[0].tostring()) + self.file.write(self.header.tostring()) + + +import itertools +from nibabel.streamlines.base_format import CompactList + +def create_compactlist_from_generator(gen): + BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. + + streamlines = CompactList() + scalars = CompactList() + properties = np.array([]) + + gen = iter(gen) + try: + first_element = next(gen) + gen = itertools.chain([first_element], gen) + except StopIteration: + return streamlines, scalars, properties + + # Allocated some buffer memory. + pts = np.asarray(first_element[0]) + scals = np.asarray(first_element[1]) + props = np.asarray(first_element[2]) + + scals_shape = scals.shape + props_shape = props.shape + + streamlines._data = np.empty((BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype) + scalars._data = np.empty((BUFFER_SIZE, scals.shape[1]), dtype=scals.dtype) + properties = np.empty((BUFFER_SIZE, props.shape[0]), dtype=props.dtype) + + offset = 0 + for i, (pts, scals, props) in enumerate(gen): + pts = np.asarray(pts) + scals = np.asarray(scals) + props = np.asarray(props) + + if scals.shape[1] != scals_shape[1]: + raise ValueError("Number of scalars differs from one" + " point or streamline to another") + + if props.shape != props_shape: + raise ValueError("Number of properties differs from one" + " streamline to another") + + end = offset + len(pts) + if end >= len(streamlines._data): + # Resize is needed (at least `len(pts)` items will be added). + streamlines._data.resize((len(streamlines._data) + len(pts)+BUFFER_SIZE, pts.shape[1])) + scalars._data.resize((len(scalars._data) + len(scals)+BUFFER_SIZE, scals.shape[1])) + + streamlines._offsets.append(offset) + streamlines._lengths.append(len(pts)) + streamlines._data[offset:offset+len(pts)] = pts + scalars._data[offset:offset+len(scals)] = scals + + offset += len(pts) + + if i >= len(properties): + properties.resize((len(properties) + BUFFER_SIZE, props.shape[0])) + + properties[i] = props + + # Clear unused memory. + streamlines._data.resize((offset, pts.shape[1])) + + if scals_shape[1] == 0: + # Because resizing an empty ndarray creates memory! + scalars._data = np.empty((offset, scals.shape[1])) + else: + scalars._data.resize((offset, scals.shape[1])) + + # Share offsets and lengths between streamlines and scalars. + scalars._offsets = streamlines._offsets + scalars._lengths = streamlines._lengths + + if props_shape[0] == 0: + # Because resizing an empty ndarray creates memory! + properties = np.empty((i+1, props.shape[0])) + else: + properties.resize((i+1, props.shape[0])) + + return streamlines, scalars, properties + class TrkFile(TractogramFile): @@ -319,19 +400,19 @@ class TrkFile(TractogramFile): MAGIC_NUMBER = b"TRACK" HEADER_SIZE = 1000 - def __init__(self, tractogram, ref, header=None): + def __init__(self, tractogram, header=None, ref=np.eye(4)): """ Parameters ---------- tractogram : ``Tractogram`` object Tractogram that will be contained in this ``TrkFile``. - ref : filename | `Nifti1Image` object | 2D array (4,4) - Reference space where streamlines live in. - - header : ``TractogramHeader`` file + header : ``TractogramHeader`` file (optional) Metadata associated to this tractogram file. + ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) + Reference space where streamlines live in. + Notes ----- Streamlines of the tractogram are assumed to be in *RAS+* and *mm* space @@ -345,6 +426,16 @@ def get_magic_number(cls): ''' Return TRK's magic number. 
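        (the byte string ``b"TRACK"``)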
''' return cls.MAGIC_NUMBER + @classmethod + def support_data_per_point(cls): + ''' Tells if this tractogram format supports saving data per point. ''' + return True + + @classmethod + def support_data_per_streamline(cls): + ''' Tells if this tractogram format supports saving data per streamline. ''' + return True + @classmethod def is_correct_format(cls, fileobj): ''' Check if the file is in TRK format. @@ -369,7 +460,7 @@ def is_correct_format(cls, fileobj): return False @classmethod - def load(cls, fileobj, lazy_load=False, ref=None): + def load(cls, fileobj, lazy_load=False): ''' Loads streamlines from a file-like object. Parameters @@ -383,9 +474,6 @@ def load(cls, fileobj, lazy_load=False, ref=None): Load streamlines in a lazy manner i.e. they will not be kept in memory. - ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) - Reference space where streamlines live in `fileobj`. - Returns ------- trk_file : ``TrkFile`` object @@ -407,7 +495,7 @@ def load(cls, fileobj, lazy_load=False, ref=None): # If voxel order implied from the affine does not match the voxel # order save in the TRK header, change the orientation. header_ornt = trk_reader.header[Field.VOXEL_ORDER] - affine_ornt = "".join(nib.orientations.aff2axcodes(trk_reader.header[Field.to_world_space])) + affine_ornt = "".join(nib.orientations.aff2axcodes(trk_reader.header[Field.VOXEL_TO_RASMM])) header_ornt = nib.orientations.axcodes2ornt(header_ornt) affine_ornt = nib.orientations.axcodes2ornt(affine_ornt) ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt) @@ -415,7 +503,7 @@ def load(cls, fileobj, lazy_load=False, ref=None): affine = np.dot(M, affine) # Applied the affine going from voxel space to rasmm. - affine = np.dot(trk_reader.header[Field.to_world_space], affine) + affine = np.dot(trk_reader.header[Field.VOXEL_TO_RASMM], affine) # TrackVis considers coordinate (0,0,0) to be the corner of the # voxel whereas streamlines returned assume (0,0,0) to be the @@ -424,6 +512,7 @@ def load(cls, fileobj, lazy_load=False, ref=None): affine[:-1, -1] -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2. if lazy_load: + # TODO when LazyTractogram has been refactored. def _apply_transform(trk_reader): for pts, scals, props in trk_reader: # TRK's streamlines are in 'voxelmm' space, we send them to voxel space. @@ -445,73 +534,32 @@ def _apply_transform(trk_reader): if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: tractogram.properties = lambda: [] - # elif Field.NB_POINTS in trk_reader.header: - # # 'count' field is provided, we can avoid creating list of numpy - # # arrays (more memory efficient). 
- - # nb_streamlines = trk_reader.header[Field.NB_STREAMLINES] - # nb_points = trk_reader.header[Field.NB_POINTS] - - # points = CompactList() - # points._data = np.empty((nb_points, 3), dtype=np.float32) - - # scalars = CompactList() - # scalars._data = np.empty((nb_points, trk_reader.header[Field.NB_SCALARS_PER_POINT]), - # dtype=np.float32) - - # properties = CompactList() - # properties._data = np.empty((nb_streamlines, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]), - # dtype=np.float32) - - # offset = 0 - # offsets = [] - # lengths = [] - # for i, (pts, scals, props) in enumerate(trk_reader): - # offsets.append(offset) - # lengths.append(len(pts)) - # points._data[offset:offset+len(pts)] = pts - # scalars._data[offset:offset+len(scals)] = scals - # properties._data[i] = props - # offset += len(pts) - - # points.offsets = offsets - # scalars.offsets = offsets - # points.lengths = lengths - # scalars.lengths = lengths - - # streamlines = Tractogram(points, scalars, properties) - # streamlines.apply_affine(affine) - - # # Overwrite scalars and properties if there is none - # if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: - # streamlines.scalars = [] - # if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: - # streamlines.properties = [] - else: - tractogram = Tractogram.create_from_generator(trk_reader) - #tractogram = Tractogram(*zip(*trk_reader)) - tractogram.apply_affine(affine) - - # Overwrite scalars and properties if there is none - if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: - tractogram.scalars = [] - if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: - tractogram.properties = [] - - # Set available common information about streamlines in the header - tractogram.header.to_world_space = affine - - # If 'count' field is 0, i.e. not provided, we don't set `nb_streamlines` - if trk_reader.header[Field.NB_STREAMLINES] > 0: - tractogram.header.nb_streamlines = trk_reader.header[Field.NB_STREAMLINES] - - # Keep extra information about TRK format - tractogram.header.extra = trk_reader.header + streamlines, scalars, properties = create_compactlist_from_generator(trk_reader) + tractogram = Tractogram(streamlines) + + if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0: + if len(trk_reader.header['scalar_name'][0]) > 0: + for i in range(trk_reader.header[Field.NB_SCALARS_PER_POINT]): + clist = CompactList() + clist._data = scalars._data[:, i] + clist._offsets = scalars._offsets + clist._lengths = scalars._lengths + tractogram.data_per_point[trk_reader.header['scalar_name'][i]] = clist + else: + tractogram.data_per_point['scalars'] = scalars + + if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: + if len(trk_reader.header['property_name'][0]) > 0: + for i in range(trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]): + tractogram.data_per_streamline[trk_reader.header['property_name'][i]] = properties[:, i] + else: + tractogram.data_per_streamline['properties'] = properties + + # Bring tractogram to RAS+ and mm space + tractogram.apply_affine(affine) ## Perform some integrity checks - #if trk_reader.header[Field.VOXEL_ORDER] != tractogram.header.voxel_order: - # raise HeaderError("'voxel_order' does not match the affine.") #if tractogram.header.voxel_sizes != trk_reader.header[Field.VOXEL_SIZES]: # raise HeaderError("'voxel_sizes' does not match the affine.") #if tractogram.header.nb_scalars_per_point != trk_reader.header[Field.NB_SCALARS_PER_POINT]: @@ -532,14 +580,8 @@ def save(self, fileobj): of the TRK header data). 
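
        Notes
        -----
        Data-dependent header fields (the number of scalars per point and
        the number of properties per streamline) are recomputed from the
        tractogram just before writing.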
''' # Update header using the tractogram. - self.header.nb_scalars_per_point = 0 - if self.tractogram.scalars.shape is not None: - self.header.nb_scalars_per_point = len(self.tractogram.scalars[0]) - - self.header.nb_properties_per_streamline = 0 - if self.tractogram.properties.shape is not None: - self.header.nb_properties_per_streamline = len(self.tractogram.properties[0]) - + self.header.nb_scalars_per_point = sum(map(lambda e: len(e[0]), self.tractogram.data_per_point.values())) + self.header.nb_properties_per_streamline = sum(map(lambda e: len(e[0]), self.tractogram.data_per_streamline.values())) trk_writer = TrkWriter(fileobj, self.header) trk_writer.write(self.tractogram) @@ -572,7 +614,7 @@ def pretty_print(fileobj): info += "scalar_name:\n {0}".format("\n".join(hdr['scalar_name'])) info += "nb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) info += "property_name:\n {0}".format("\n".join(hdr['property_name'])) - info += "vox_to_world: {0}".format(hdr[Field.to_world_space]) + info += "vox_to_world: {0}".format(hdr[Field.VOXEL_TO_RASMM]) info += "voxel_order: {0}".format(hdr[Field.VOXEL_ORDER]) info += "image_orientation_patient: {0}".format(hdr['image_orientation_patient']) info += "pad1: {0}".format(hdr['pad1']) From 1872d4f94767fb7fffc0467de22dd7fd99b33f34 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 3 Nov 2015 08:38:16 -0500 Subject: [PATCH 021/135] Refactored LazyTractogram --- nibabel/streamlines/base_format.py | 126 +++++++++++++++++++---------- 1 file changed, 84 insertions(+), 42 deletions(-) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 9276f01f2c..5d5cd0e6e0 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -322,7 +322,6 @@ def __len__(self): def copy(self): """ Returns a copy of this `Tractogram` object. """ - new_data_per_streamline = {} for key in self.data_per_streamline: new_data_per_streamline[key] = self.data_per_streamline[key].copy() @@ -390,10 +389,12 @@ class LazyTractogram(Tractogram): If provided, ``scalars`` and ``properties`` must yield the same number of values as ``streamlines``. 
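
    Examples
    --------
    A minimal sketch, assuming ``pts`` is a list of (N, 3) arrays and
    ``props`` a matching list of property vectors; each callable must
    return a fresh generator on every call::

        tractogram = LazyTractogram(
            streamlines=lambda: (p for p in pts),
            data_per_streamline={'properties': lambda: (v for v in props)})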
''' - def __init__(self, streamlines_func=lambda:[], scalars_func=lambda: [], properties_func=lambda: [], getitem_func=None): - super(LazyTractogram, self).__init__(streamlines_func, scalars_func, properties_func) - self._data = lambda: zip_longest(self.streamlines, self.scalars, self.properties, fillvalue=[]) - self._getitem = getitem_func + def __init__(self, streamlines=lambda:[], data_per_streamline=None, data_per_point=None): + super(LazyTractogram, self).__init__(streamlines, data_per_streamline, data_per_point) + self.nb_streamlines = None + self._data = None + self._getitem = None + self._affine_to_apply = np.eye(4) @classmethod def create_from_data(cls, data_func): @@ -422,6 +423,13 @@ def create_from_data(cls, data_func): @property def streamlines(self): + if not np.all(self._affine_to_apply == np.eye(4)): + def _transform(): + for s in self._streamlines(): + yield apply_affine(self._affine_to_apply, s) + + return _transform() + return self._streamlines() @streamlines.setter @@ -432,34 +440,70 @@ def streamlines(self, value): self._streamlines = value @property - def scalars(self): - return self._scalars() + def data_per_streamline(self): + return self._data_per_streamline - @scalars.setter - def scalars(self, value): - if not callable(value): - raise TypeError("`scalars` must be a coroutine.") + @data_per_streamline.setter + def data_per_streamline(self, value): + if value is None: + value = {} - self._scalars = value - self.header.nb_scalars_per_point = 0 - scalars = pop(self.scalars) - if scalars is not None and len(scalars) > 0: - self.header.nb_scalars_per_point = len(scalars[0]) + self._data_per_streamline = {} + for k, v in value.items(): + if not callable(v): + raise TypeError("`data_per_streamline` must be a dict of coroutines.") + + self._data_per_streamline[k] = v @property - def properties(self): - return self._properties() + def data_per_point(self): + return self._data_per_point - @properties.setter - def properties(self, value): + @data_per_point.setter + def data_per_point(self, value): + if value is None: + value = {} + + self._data_per_point = {} + for k, v in value.items(): + if not callable(v): + raise TypeError("`data_per_point` must be a dict of coroutines.") + + self._data_per_point[k] = v + + @property + def data(self): + if self._data is not None: + return self._data() + + def _gen_data(): + data_per_streamline_generators = {} + for k, v in self.data_per_streamline.items(): + data_per_streamline_generators[k] = iter(v()) + + data_per_point_generators = {} + for k, v in self.data_per_point.items(): + data_per_point_generators[k] = iter(v()) + + for s in self.streamlines: + data_for_streamline = {} + for k, v in data_per_streamline_generators.items(): + data_for_streamline[k] = next(v) + + data_for_points = {} + for k, v in data_per_point_generators.items(): + data_for_points[k] = v() + + yield TractogramItem(s, data_for_streamline, data_for_points) + + return _gen_data() + + @data.setter + def data(self, value): if not callable(value): - raise TypeError("`properties` must be a coroutine.") + raise TypeError("`data` must be a coroutine.") - self._properties = value - self.header.nb_properties_per_streamline = 0 - properties = pop(self.properties) - if properties is not None: - self.header.nb_properties_per_streamline = len(properties) + self._data = value def __getitem__(self, idx): if self._getitem is None: @@ -469,45 +513,43 @@ def __getitem__(self, idx): def __iter__(self): i = 0 - for i, s in enumerate(self._data(), start=1): - yield TractogramItem(*s) 
+ for i, tractogram_item in enumerate(self.data, start=1): + yield tractogram_item # To be safe, update information about number of streamlines. - self.header.nb_streamlines = i + self.nb_streamlines = i def __len__(self): # If length is unknown, we obtain it by iterating through streamlines. - if self.header.nb_streamlines is None: + if self.nb_streamlines is None: warn("Number of streamlines will be determined manually by looping" " through the streamlines. If you know the actual number of" " streamlines, you might want to set it beforehand via" " `self.header.nb_streamlines`." " Note this will consume any generators used to create this" " `LazyTractogram` object.", UsageWarning) - return sum(1 for _ in self) + return sum(1 for _ in self.streamlines) - return self.header.nb_streamlines + return self.nb_streamlines def copy(self): """ Returns a copy of this `LazyTractogram` object. """ - streamlines = LazyTractogram(self._streamlines, self._scalars, self._properties) - streamlines._header = self.header.copy() - return streamlines + tractogram = LazyTractogram(self._streamlines, + self._data_per_streamline, + self._data_per_point) + tractogram.nb_streamlines = self.nb_streamlines + tractogram._data = self._data + return tractogram - def transform(self, affine): + def apply_affine(self, affine): """ Applies an affine transformation on the streamlines. Parameters ---------- affine : 2D array (4,4) Transformation that will be applied on each streamline. - - Returns - ------- - streamlines : `LazyTractogram` object - Tractogram living in a space defined by `affine`. """ - return super(LazyTractogram, self).transform(affine, lazy=True) + self._affine_to_apply = np.dot(affine, self._affine_to_apply) class abstractclassmethod(classmethod): From 1a72cbf585c125e3737680f42585114fee6bb7c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 8 Nov 2015 13:23:58 -0500 Subject: [PATCH 022/135] Added CompactList to init --- nibabel/streamlines/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index cfd8990741..9c24d6f2b0 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -1,4 +1,5 @@ from .header import TractogramHeader +from .base_format import CompactList from .base_format import Tractogram, LazyTractogram from nibabel.streamlines.trk import TrkFile From a7628023157ba45c083b6c26057d5aaceb400b55 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 8 Nov 2015 13:36:03 -0500 Subject: [PATCH 023/135] Added get_streamlines method to TractogramFile --- nibabel/streamlines/base_format.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index 5d5cd0e6e0..b5c9f3c03f 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -583,6 +583,9 @@ def header(self): def get_tractogram(self): return self.tractogram + def get_streamlines(self): + return self.streamlines + def get_header(self): return self.header From 91837dd975a2f9db15f473f1ddb56d0e54ac72ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 8 Nov 2015 14:07:01 -0500 Subject: [PATCH 024/135] Added save and load support for compact_list --- nibabel/streamlines/tests/test_utils.py | 26 +++++++++++++++++++++++++ nibabel/streamlines/utils.py | 17 ++++++++++++++++ 2 files changed, 43 insertions(+) diff --git 
a/nibabel/streamlines/tests/test_utils.py b/nibabel/streamlines/tests/test_utils.py index 4cceef35c4..b5c67e0770 100644 --- a/nibabel/streamlines/tests/test_utils.py +++ b/nibabel/streamlines/tests/test_utils.py @@ -1,5 +1,6 @@ import os import unittest +import tempfile import numpy as np import nibabel as nib @@ -7,7 +8,9 @@ from numpy.testing import assert_array_equal from nose.tools import assert_equal, assert_raises, assert_true +from ..base_format import CompactList from ..utils import pop, get_affine_from_reference +from ..utils import save_compact_list, load_compact_list def test_peek(): @@ -33,3 +36,26 @@ def test_get_affine_from_reference(): # Get affine from a `SpatialImage` using by its filename. assert_array_equal(get_affine_from_reference(filename), affine) + + +def test_save_and_load_compact_list(): + + with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + clist = CompactList() + save_compact_list(f, clist) + f.seek(0, os.SEEK_SET) + loaded_clist = load_compact_list(f) + assert_array_equal(loaded_clist._data, clist._data) + assert_array_equal(loaded_clist._offsets, clist._offsets) + assert_array_equal(loaded_clist._lengths, clist._lengths) + + with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + clist = CompactList(data) + save_compact_list(f, clist) + f.seek(0, os.SEEK_SET) + loaded_clist = load_compact_list(f) + assert_array_equal(loaded_clist._data, clist._data) + assert_array_equal(loaded_clist._offsets, clist._offsets) + assert_array_equal(loaded_clist._lengths, clist._lengths) diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index 7bbbe1ef8d..a44b25fe1b 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -33,3 +33,20 @@ def pop(iterable): "Returns the next item from the iterable else None" value = list(itertools.islice(iterable, 1)) return value[0] if len(value) > 0 else None + + +def save_compact_list(filename, compact_list): + np.savez(filename, + data=compact_list._data, + offsets=compact_list._offsets, + lengths=compact_list._lengths) + + +def load_compact_list(filename): + from .base_format import CompactList + content = np.load(filename) + compact_list = CompactList() + compact_list._data = content["data"] + compact_list._offsets = content["offsets"].tolist() + compact_list._lengths = content["lengths"].tolist() + return compact_list From 0b99ae9900b1d985eaa1e7e235fe8adf3ce673d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 8 Nov 2015 14:08:12 -0500 Subject: [PATCH 025/135] DOC: load and save utils functions --- nibabel/streamlines/utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index a44b25fe1b..7fbfab5106 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -36,6 +36,7 @@ def pop(iterable): def save_compact_list(filename, compact_list): + """ Saves a `CompactList` object to a .npz file. """ np.savez(filename, data=compact_list._data, offsets=compact_list._offsets, @@ -43,6 +44,7 @@ def save_compact_list(filename, compact_list): def load_compact_list(filename): + """ Loads a `CompactList` object from a .npz file. 
""" from .base_format import CompactList content = np.load(filename) compact_list = CompactList() From cbdc2a374136e4b5e3a8839ad4c0c0fb3837015f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 11 Nov 2015 10:21:22 -0500 Subject: [PATCH 026/135] Refactored streamlines API and added unit tests for LazyTractogram --- nibabel/streamlines/__init__.py | 4 +- nibabel/streamlines/base_format.py | 640 ------------------ nibabel/streamlines/compact_list.py | 198 ++++++ .../streamlines/tests/test_compact_list.py | 222 ++++++ ...test_base_format.py => test_tractogram.py} | 315 ++------- nibabel/streamlines/tests/test_utils.py | 27 - nibabel/streamlines/tractogram.py | 390 +++++++++++ nibabel/streamlines/tractogram_file.py | 103 +++ nibabel/streamlines/trk.py | 17 +- nibabel/streamlines/utils.py | 19 - 10 files changed, 989 insertions(+), 946 deletions(-) create mode 100644 nibabel/streamlines/compact_list.py create mode 100644 nibabel/streamlines/tests/test_compact_list.py rename nibabel/streamlines/tests/{test_base_format.py => test_tractogram.py} (52%) create mode 100644 nibabel/streamlines/tractogram.py create mode 100644 nibabel/streamlines/tractogram_file.py diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 9c24d6f2b0..abf9ca27f3 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -1,6 +1,6 @@ from .header import TractogramHeader -from .base_format import CompactList -from .base_format import Tractogram, LazyTractogram +from .compact_list import CompactList +from .tractogram import Tractogram, LazyTractogram from nibabel.streamlines.trk import TrkFile #from nibabel.streamlines.tck import TckFile diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py index b5c9f3c03f..e694e7e0c4 100644 --- a/nibabel/streamlines/base_format.py +++ b/nibabel/streamlines/base_format.py @@ -1,18 +1,3 @@ -import itertools -import numpy as np -from warnings import warn - -from abc import ABCMeta, abstractmethod, abstractproperty - -from nibabel.externals.six.moves import zip_longest -from nibabel.affines import apply_affine - -from .header import TractogramHeader -from .utils import pop - - -class UsageWarning(Warning): - pass class HeaderWarning(Warning): @@ -25,628 +10,3 @@ class HeaderError(Exception): class DataError(Exception): pass - - -class CompactList(object): - """ Class for compacting list of ndarrays with matching shape except for - the first dimension. - """ - def __init__(self, iterable=None): - """ - Parameters - ---------- - iterable : iterable (optional) - If specified, create a ``CompactList`` object initialized from - iterable's items. Otherwise, create an empty ``CompactList``. - - Notes - ----- - If `iterable` is a ``CompactList`` object, a view is returned and no - memory is allocated. For an actual copy use the `.copy()` method. - """ - # Create new empty `CompactList` object. - self._data = None - self._offsets = [] - self._lengths = [] - - if isinstance(iterable, CompactList): - # Create a view. - self._data = iterable._data - self._offsets = iterable._offsets - self._lengths = iterable._lengths - - elif iterable is not None: - # Initialize the `CompactList` object from iterable's item. - BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. 
- - offset = 0 - for i, e in enumerate(iterable): - e = np.asarray(e) - if i == 0: - self._data = np.empty((BUFFER_SIZE,) + e.shape[1:], dtype=e.dtype) - - end = offset + len(e) - if end >= len(self._data): - # Resize is needed (at least `len(e)` items will be added). - self._data.resize((len(self._data) + len(e)+BUFFER_SIZE,) + self.shape) - - self._offsets.append(offset) - self._lengths.append(len(e)) - self._data[offset:offset+len(e)] = e - offset += len(e) - - # Clear unused memory. - if self._data is not None: - self._data.resize((offset,) + self.shape) - - @property - def shape(self): - """ Returns the matching shape of the elements in this compact list. """ - if self._data is None: - return None - - return self._data.shape[1:] - - def append(self, element): - """ Appends `element` to this compact list. - - Parameters - ---------- - element : ndarray - Element to append. The shape must match already inserted elements - shape except for the first dimension. - - Notes - ----- - If you need to add multiple elements you should consider - `CompactList.extend`. - """ - if self._data is None: - self._data = np.asarray(element).copy() - self._offsets.append(0) - self._lengths.append(len(element)) - return - - if element.shape[1:] != self.shape: - raise ValueError("All dimensions, except the first one, must match exactly") - - self._offsets.append(len(self._data)) - self._lengths.append(len(element)) - self._data = np.append(self._data, element, axis=0) - - def extend(self, elements): - """ Appends all `elements` to this compact list. - - Parameters - ---------- - element : list of ndarrays, ``CompactList`` object - Elements to append. The shape must match already inserted elements - shape except for the first dimension. - """ - if isinstance(elements, CompactList): - self._data = np.concatenate([self._data, elements._data], axis=0) - offset = self._offsets[-1] + self._lengths[-1] if len(self) > 0 else 0 - self._lengths.extend(elements._lengths) - self._offsets.extend(np.cumsum([offset] + elements._lengths).tolist()[:-1]) - else: - self._data = np.concatenate([self._data] + list(elements), axis=0) - offset = self._offsets[-1] + self._lengths[-1] if len(self) > 0 else 0 - lengths = map(len, elements) - self._lengths.extend(lengths) - self._offsets.extend(np.cumsum([offset] + lengths).tolist()[:-1]) - - def copy(self): - """ Creates a copy of this ``CompactList`` object. """ - # We cannot just deepcopy this object since we don't know if it has been created - # using slicing. If it is the case, `self.data` probably contains more data than necessary - # so we copy only elements according to `self._offsets`. - compact_list = CompactList() - total_lengths = np.sum(self._lengths) - compact_list._data = np.empty((total_lengths,) + self._data.shape[1:], dtype=self._data.dtype) - - cur_offset = 0 - for offset, lengths in zip(self._offsets, self._lengths): - compact_list._offsets.append(cur_offset) - compact_list._lengths.append(lengths) - compact_list._data[cur_offset:cur_offset+lengths] = self._data[offset:offset+lengths] - cur_offset += lengths - - return compact_list - - def __getitem__(self, idx): - """ Gets element(s) through indexing. - - Parameters - ---------- - idx : int, slice or list - Index of the element(s) to get. - - Returns - ------- - ndarray object(s) - When `idx` is a int, returns a single ndarray. - When `idx` is either a slice or a list, returns a list of ndarrays. 
- """ - if isinstance(idx, int) or isinstance(idx, np.integer): - return self._data[self._offsets[idx]:self._offsets[idx]+self._lengths[idx]] - - elif type(idx) is slice: - # TODO: Should we have a CompactListView class that would be - # returned when slicing? - compact_list = CompactList() - compact_list._data = self._data - compact_list._offsets = self._offsets[idx] - compact_list._lengths = self._lengths[idx] - return compact_list - - elif type(idx) is list: - # TODO: Should we have a CompactListView class that would be - # returned when doing advance indexing? - compact_list = CompactList() - compact_list._data = self._data - compact_list._offsets = [self._offsets[i] for i in idx] - compact_list._lengths = [self._lengths[i] for i in idx] - return compact_list - - raise TypeError("Index must be a int or a slice! Not " + str(type(idx))) - - def __iter__(self): - if len(self._lengths) != len(self._offsets): - raise ValueError("CompactList object corrupted: len(self._lengths) != len(self._offsets)") - - for offset, lengths in zip(self._offsets, self._lengths): - yield self._data[offset: offset+lengths] - - def __len__(self): - return len(self._offsets) - - def __repr__(self): - return repr(list(self)) - - -class TractogramItem(object): - """ Class containing information about one streamline. - - ``TractogramItem`` objects have three main properties: `streamline`, - `data_for_streamline`, and `data_for_points`. - - Parameters - ---------- - streamline : ndarray of shape (N, 3) - Points of this streamline represented as an ndarray of shape (N, 3) - where N is the number of points. - - data_for_streamline : dict - - data_for_points : dict - """ - def __init__(self, streamline, data_for_streamline, data_for_points): - self.streamline = np.asarray(streamline) - self.data_for_streamline = data_for_streamline - self.data_for_points = data_for_points - - def __iter__(self): - return iter(self.streamline) - - def __len__(self): - return len(self.streamline) - - -class Tractogram(object): - """ Class containing information about streamlines. - - Tractogram objects have three main properties: ``streamlines`` - - Parameters - ---------- - streamlines : list of ndarray of shape (Nt, 3) - Sequence of T streamlines. One streamline is an ndarray of shape - (Nt, 3) where Nt is the number of points of streamline t. - - data_per_streamline : dictionary of list of ndarray of shape (P,) - Sequence of T ndarrays of shape (P,) where T is the number of - streamlines defined by ``streamlines``, P is the number of properties - associated to each streamline. - - data_per_point : dictionary of list of ndarray of shape (Nt, M) - Sequence of T ndarrays of shape (Nt, M) where T is the number of - streamlines defined by ``streamlines``, Nt is the number of points - for a particular streamline t and M is the number of scalars - associated to each point (excluding the three coordinates). 
- - """ - def __init__(self, streamlines=None, - data_per_streamline=None, - data_per_point=None): - - self.streamlines = streamlines - self.data_per_streamline = data_per_streamline - self.data_per_point = data_per_point - - @property - def streamlines(self): - return self._streamlines - - @streamlines.setter - def streamlines(self, value): - self._streamlines = CompactList(value) - - @property - def data_per_streamline(self): - return self._data_per_streamline - - @data_per_streamline.setter - def data_per_streamline(self, value): - if value is None: - value = {} - - self._data_per_streamline = {} - for k, v in value.items(): - self._data_per_streamline[k] = np.asarray(v) - - @property - def data_per_point(self): - return self._data_per_point - - @data_per_point.setter - def data_per_point(self, value): - if value is None: - value = {} - - self._data_per_point = {} - for k, v in value.items(): - self._data_per_point[k] = CompactList(v) - - def __iter__(self): - for i in range(len(self.streamlines)): - yield self[i] - - def __getitem__(self, idx): - pts = self.streamlines[idx] - - new_data_per_streamline = {} - for key in self.data_per_streamline: - new_data_per_streamline[key] = self.data_per_streamline[key][idx] - - new_data_per_point = {} - for key in self.data_per_point: - new_data_per_point[key] = self.data_per_point[key][idx] - - if type(idx) is slice: - return Tractogram(pts, new_data_per_streamline, new_data_per_point) - - return TractogramItem(pts, new_data_per_streamline, new_data_per_point) - - def __len__(self): - return len(self.streamlines) - - def copy(self): - """ Returns a copy of this `Tractogram` object. """ - new_data_per_streamline = {} - for key in self.data_per_streamline: - new_data_per_streamline[key] = self.data_per_streamline[key].copy() - - new_data_per_point = {} - for key in self.data_per_point: - new_data_per_point[key] = self.data_per_point[key].copy() - - tractogram = Tractogram(self.streamlines.copy(), - new_data_per_streamline, - new_data_per_point) - return tractogram - - def apply_affine(self, affine): - """ Applies an affine transformation on the points of each streamline. - - This is performed in-place. - - Parameters - ---------- - affine : 2D array (4,4) - Transformation that will be applied on each streamline. - """ - if len(self.streamlines) == 0: - return - - BUFFER_SIZE = 10000000 # About 128 Mb since pts shape is 3. - for i in range(0, len(self.streamlines._data), BUFFER_SIZE): - pts = self.streamlines._data[i:i+BUFFER_SIZE] - self.streamlines._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts) - - -class LazyTractogram(Tractogram): - ''' Class containing information about streamlines. - - Tractogram objects have four main properties: ``header``, ``streamlines``, - ``scalars`` and ``properties``. Tractogram objects are iterable and - produce tuple of ``streamlines``, ``scalars`` and ``properties`` for each - streamline. - - Parameters - ---------- - streamlines_func : coroutine ouputting (Nt,3) array-like (optional) - Function yielding streamlines. One streamline is - an ndarray of shape (Nt,3) where Nt is the number of points of - streamline t. - - scalars_func : coroutine ouputting (Nt,M) array-like (optional) - Function yielding scalars for a particular streamline t. The scalars - are represented as an ndarray of shape (Nt,M) where Nt is the number - of points of that streamline t and M is the number of scalars - associated to each point (excluding the three coordinates). 
- - properties_func : coroutine ouputting (P,) array-like (optional) - Function yielding properties for a particular streamline t. The - properties are represented as an ndarray of shape (P,) where P is - the number of properties associated to each streamline. - - getitem_func : function `idx -> 3-tuples` (optional) - Function returning a subset of the tractogram given an index or a - slice (i.e. the __getitem__ function to use). - - Notes - ----- - If provided, ``scalars`` and ``properties`` must yield the same number of - values as ``streamlines``. - ''' - def __init__(self, streamlines=lambda:[], data_per_streamline=None, data_per_point=None): - super(LazyTractogram, self).__init__(streamlines, data_per_streamline, data_per_point) - self.nb_streamlines = None - self._data = None - self._getitem = None - self._affine_to_apply = np.eye(4) - - @classmethod - def create_from_data(cls, data_func): - ''' Saves streamlines to a file-like object. - - Parameters - ---------- - data_func : coroutine ouputting tuple (optional) - Function yielding 3-tuples, (streamlines, scalars, properties). - Streamlines are represented as an ndarray of shape (Nt,3), scalars - as an ndarray of shape (Nt,M) and properties as an ndarray of shape - (P,) where Nt is the number of points for a particular - streamline t, M is the number of scalars associated to each point - (excluding the three coordinates) and P is the number of properties - associated to each streamline. - ''' - if not callable(data_func): - raise TypeError("`data` must be a coroutine.") - - lazy_streamlines = cls() - lazy_streamlines._data = data_func - lazy_streamlines.streamlines = lambda: (x[0] for x in data_func()) - lazy_streamlines.scalars = lambda: (x[1] for x in data_func()) - lazy_streamlines.properties = lambda: (x[2] for x in data_func()) - return lazy_streamlines - - @property - def streamlines(self): - if not np.all(self._affine_to_apply == np.eye(4)): - def _transform(): - for s in self._streamlines(): - yield apply_affine(self._affine_to_apply, s) - - return _transform() - - return self._streamlines() - - @streamlines.setter - def streamlines(self, value): - if not callable(value): - raise TypeError("`streamlines` must be a coroutine.") - - self._streamlines = value - - @property - def data_per_streamline(self): - return self._data_per_streamline - - @data_per_streamline.setter - def data_per_streamline(self, value): - if value is None: - value = {} - - self._data_per_streamline = {} - for k, v in value.items(): - if not callable(v): - raise TypeError("`data_per_streamline` must be a dict of coroutines.") - - self._data_per_streamline[k] = v - - @property - def data_per_point(self): - return self._data_per_point - - @data_per_point.setter - def data_per_point(self, value): - if value is None: - value = {} - - self._data_per_point = {} - for k, v in value.items(): - if not callable(v): - raise TypeError("`data_per_point` must be a dict of coroutines.") - - self._data_per_point[k] = v - - @property - def data(self): - if self._data is not None: - return self._data() - - def _gen_data(): - data_per_streamline_generators = {} - for k, v in self.data_per_streamline.items(): - data_per_streamline_generators[k] = iter(v()) - - data_per_point_generators = {} - for k, v in self.data_per_point.items(): - data_per_point_generators[k] = iter(v()) - - for s in self.streamlines: - data_for_streamline = {} - for k, v in data_per_streamline_generators.items(): - data_for_streamline[k] = next(v) - - data_for_points = {} - for k, v in 
data_per_point_generators.items(): - data_for_points[k] = v() - - yield TractogramItem(s, data_for_streamline, data_for_points) - - return _gen_data() - - @data.setter - def data(self, value): - if not callable(value): - raise TypeError("`data` must be a coroutine.") - - self._data = value - - def __getitem__(self, idx): - if self._getitem is None: - raise AttributeError('`LazyTractogram` does not support indexing.') - - return self._getitem(idx) - - def __iter__(self): - i = 0 - for i, tractogram_item in enumerate(self.data, start=1): - yield tractogram_item - - # To be safe, update information about number of streamlines. - self.nb_streamlines = i - - def __len__(self): - # If length is unknown, we obtain it by iterating through streamlines. - if self.nb_streamlines is None: - warn("Number of streamlines will be determined manually by looping" - " through the streamlines. If you know the actual number of" - " streamlines, you might want to set it beforehand via" - " `self.header.nb_streamlines`." - " Note this will consume any generators used to create this" - " `LazyTractogram` object.", UsageWarning) - return sum(1 for _ in self.streamlines) - - return self.nb_streamlines - - def copy(self): - """ Returns a copy of this `LazyTractogram` object. """ - tractogram = LazyTractogram(self._streamlines, - self._data_per_streamline, - self._data_per_point) - tractogram.nb_streamlines = self.nb_streamlines - tractogram._data = self._data - return tractogram - - def apply_affine(self, affine): - """ Applies an affine transformation on the streamlines. - - Parameters - ---------- - affine : 2D array (4,4) - Transformation that will be applied on each streamline. - """ - self._affine_to_apply = np.dot(affine, self._affine_to_apply) - - -class abstractclassmethod(classmethod): - __isabstractmethod__ = True - - def __init__(self, callable): - callable.__isabstractmethod__ = True - super(abstractclassmethod, self).__init__(callable) - - -class TractogramFile(object): - ''' Convenience class to encapsulate tractogram file format. ''' - __metaclass__ = ABCMeta - - def __init__(self, tractogram, header=None): - self._tractogram = tractogram - self._header = TractogramHeader() if header is None else header - - @property - def tractogram(self): - return self._tractogram - - @property - def streamlines(self): - return self.tractogram.streamlines - - @property - def header(self): - return self._header - - def get_tractogram(self): - return self.tractogram - - def get_streamlines(self): - return self.streamlines - - def get_header(self): - return self.header - - @classmethod - def get_magic_number(cls): - ''' Returns streamlines file's magic number. ''' - raise NotImplementedError() - - @classmethod - def support_data_per_point(cls): - ''' Tells if this tractogram format supports saving data per point. ''' - raise NotImplementedError() - - @classmethod - def support_data_per_streamline(cls): - ''' Tells if this tractogram format supports saving data per streamline. ''' - raise NotImplementedError() - - @classmethod - def is_correct_format(cls, fileobj): - ''' Checks if the file has the right streamlines file format. - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to a streamlines file (and ready to read from the - beginning of the header). - Returns - ------- - is_correct_format : boolean - Returns True if `fileobj` is in the right streamlines file format. 
- ''' - raise NotImplementedError() - - @abstractclassmethod - def load(cls, fileobj, lazy_load=True): - ''' Loads streamlines from a file-like object. - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to a streamlines file (and ready to read from the - beginning of the header). - lazy_load : boolean (optional) - Load streamlines in a lazy manner i.e. they will not be kept - in memory. For postprocessing speed, turn off this option. - Returns - ------- - tractogram_file : ``TractogramFile`` object - Returns an object containing tractogram data and header - information. - ''' - raise NotImplementedError() - - @abstractmethod - def save(self, fileobj): - ''' Saves streamlines to a file-like object. - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - opened and ready to write. - ''' - raise NotImplementedError() diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py new file mode 100644 index 0000000000..87671bb5c2 --- /dev/null +++ b/nibabel/streamlines/compact_list.py @@ -0,0 +1,198 @@ +import numpy as np + + +class CompactList(object): + """ Class for compacting list of ndarrays with matching shape except for + the first dimension. + """ + def __init__(self, iterable=None): + """ + Parameters + ---------- + iterable : iterable (optional) + If specified, create a ``CompactList`` object initialized from + iterable's items. Otherwise, create an empty ``CompactList``. + + Notes + ----- + If `iterable` is a ``CompactList`` object, a view is returned and no + memory is allocated. For an actual copy use the `.copy()` method. + """ + # Create new empty `CompactList` object. + self._data = None + self._offsets = [] + self._lengths = [] + + if isinstance(iterable, CompactList): + # Create a view. + self._data = iterable._data + self._offsets = iterable._offsets + self._lengths = iterable._lengths + + elif iterable is not None: + # Initialize the `CompactList` object from iterable's item. + BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. + + offset = 0 + for i, e in enumerate(iterable): + e = np.asarray(e) + if i == 0: + self._data = np.empty((BUFFER_SIZE,) + e.shape[1:], + dtype=e.dtype) + + end = offset + len(e) + if end >= len(self._data): + # Resize is needed (at least `len(e)` items will be added). + self._data.resize((len(self._data) + len(e)+BUFFER_SIZE,) + + self.shape) + + self._offsets.append(offset) + self._lengths.append(len(e)) + self._data[offset:offset+len(e)] = e + offset += len(e) + + # Clear unused memory. + if self._data is not None: + self._data.resize((offset,) + self.shape) + + @property + def shape(self): + """ Returns the matching shape of the elements in this compact list. """ + if self._data is None: + return None + + return self._data.shape[1:] + + def append(self, element): + """ Appends `element` to this compact list. + + Parameters + ---------- + element : ndarray + Element to append. The shape must match already inserted elements + shape except for the first dimension. + + Notes + ----- + If you need to add multiple elements you should consider + `CompactList.extend`. 
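+
+        A minimal sketch (assuming ``np`` is numpy)::
+
+            clist = CompactList()
+            clist.append(np.zeros((5, 3)))  # first element fixes the shape
+            clist.append(np.ones((8, 3)))   # later elements must match (., 3)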
+ """ + if self._data is None: + self._data = np.asarray(element).copy() + self._offsets.append(0) + self._lengths.append(len(element)) + return + + if element.shape[1:] != self.shape: + raise ValueError("All dimensions, except the first one," + " must match exactly") + + self._offsets.append(len(self._data)) + self._lengths.append(len(element)) + self._data = np.append(self._data, element, axis=0) + + def extend(self, elements): + """ Appends all `elements` to this compact list. + + Parameters + ---------- + elements : list of ndarrays, ``CompactList`` object + Elements to append. The shape must match already inserted elements + shape except for the first dimension. + + """ + if isinstance(elements, CompactList): + self._data = np.concatenate([self._data, elements._data], axis=0) + lengths = elements._lengths + else: + self._data = np.concatenate([self._data] + list(elements), axis=0) + lengths = map(len, elements) + + idx = self._offsets[-1] + self._lengths[-1] if len(self) > 0 else 0 + self._lengths.extend(lengths) + self._offsets.extend(np.cumsum([idx] + lengths).tolist()[:-1]) + + def copy(self): + """ Creates a copy of this ``CompactList`` object. """ + # We do not simply deepcopy this object since we might have a chance + # to use less memory. For example, if the compact list being copied + # is the result of a slicing operation on a compact list. + clist = CompactList() + total_lengths = np.sum(self._lengths) + clist._data = np.empty((total_lengths,) + self._data.shape[1:], + dtype=self._data.dtype) + + idx = 0 + for offset, length in zip(self._offsets, self._lengths): + clist._offsets.append(idx) + clist._lengths.append(length) + clist._data[idx:idx+length] = self._data[offset:offset+length] + idx += length + + return clist + + def __getitem__(self, idx): + """ Gets element(s) through indexing. + + Parameters + ---------- + idx : int, slice or list + Index of the element(s) to get. + + Returns + ------- + ndarray object(s) + When `idx` is a int, returns a single ndarray. + When `idx` is either a slice or a list, returns a list of ndarrays. + """ + if isinstance(idx, int) or isinstance(idx, np.integer): + start = self._offsets[idx] + return self._data[start:start+self._lengths[idx]] + + elif type(idx) is slice: + compact_list = CompactList() + compact_list._data = self._data + compact_list._offsets = self._offsets[idx] + compact_list._lengths = self._lengths[idx] + return compact_list + + elif type(idx) is list: + compact_list = CompactList() + compact_list._data = self._data + compact_list._offsets = [self._offsets[i] for i in idx] + compact_list._lengths = [self._lengths[i] for i in idx] + return compact_list + + raise TypeError("Index must be a int or a slice! Not " + str(type(idx))) + + def __iter__(self): + if len(self._lengths) != len(self._offsets): + raise ValueError("CompactList object corrupted:" + " len(self._lengths) != len(self._offsets)") + + for offset, lengths in zip(self._offsets, self._lengths): + yield self._data[offset: offset+lengths] + + def __len__(self): + return len(self._offsets) + + def __repr__(self): + return repr(list(self)) + + +def save_compact_list(filename, compact_list): + """ Saves a `CompactList` object to a .npz file. """ + np.savez(filename, + data=compact_list._data, + offsets=compact_list._offsets, + lengths=compact_list._lengths) + + +def load_compact_list(filename): + """ Loads a `CompactList` object from a .npz file. 
""" + content = np.load(filename) + compact_list = CompactList() + compact_list._data = content["data"] + compact_list._offsets = content["offsets"].tolist() + compact_list._lengths = content["lengths"].tolist() + return compact_list diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py new file mode 100644 index 0000000000..2211d9d303 --- /dev/null +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -0,0 +1,222 @@ +import os +import unittest +import tempfile +import numpy as np + +from nose.tools import assert_equal, assert_raises, assert_true +from numpy.testing import assert_array_equal +from nibabel.externals.six.moves import zip, zip_longest + +from ..compact_list import (CompactList, + load_compact_list, + save_compact_list) + + +class TestCompactList(unittest.TestCase): + + def setUp(self): + rng = np.random.RandomState(42) + self.data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + self.lengths = map(len, self.data) + self.clist = CompactList(self.data) + + def test_creating_empty_compactlist(self): + clist = CompactList() + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + def test_creating_compactlist_from_list(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = map(len, data) + + clist = CompactList(data) + assert_equal(len(clist), len(data)) + assert_equal(len(clist._offsets), len(data)) + assert_equal(len(clist._lengths), len(data)) + assert_equal(clist._data.shape[0], sum(lengths)) + assert_equal(clist._data.shape[1], 3) + assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist._lengths, lengths) + assert_equal(clist.shape, data[0].shape[1:]) + + # Empty list + clist = CompactList([]) + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + def test_creating_compactlist_from_generator(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = map(len, data) + + gen = (e for e in data) + clist = CompactList(gen) + assert_equal(len(clist), len(data)) + assert_equal(len(clist._offsets), len(data)) + assert_equal(len(clist._lengths), len(data)) + assert_equal(clist._data.shape[0], sum(lengths)) + assert_equal(clist._data.shape[1], 3) + assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist._lengths, lengths) + assert_equal(clist.shape, data[0].shape[1:]) + + # Already consumed generator + clist = CompactList(gen) + assert_equal(len(clist), 0) + assert_equal(len(clist._offsets), 0) + assert_equal(len(clist._lengths), 0) + assert_true(clist._data is None) + assert_true(clist.shape is None) + + def test_creating_compactlist_from_compact_list(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = map(len, data) + + clist = CompactList(data) + clist2 = CompactList(clist) + assert_equal(len(clist2), len(data)) + assert_equal(len(clist2._offsets), len(data)) + assert_equal(len(clist2._lengths), len(data)) + assert_equal(clist2._data.shape[0], sum(lengths)) + assert_equal(clist2._data.shape[1], 3) + assert_equal(clist2._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist2._lengths, lengths) + 
assert_equal(clist2.shape, data[0].shape[1:]) + + def test_compactlist_iter(self): + for e, d in zip(self.clist, self.data): + assert_array_equal(e, d) + + def test_compactlist_copy(self): + clist = self.clist.copy() + assert_array_equal(clist._data, self.clist._data) + assert_true(clist._data is not self.clist._data) + assert_array_equal(clist._offsets, self.clist._offsets) + assert_true(clist._offsets is not self.clist._offsets) + assert_array_equal(clist._lengths, self.clist._lengths) + assert_true(clist._lengths is not self.clist._lengths) + + assert_equal(clist.shape, self.clist.shape) + + # When taking a copy of a `CompactList` generated by slicing. + # Only needed data should be kept. + clist = self.clist[::2].copy() + + assert_true(clist._data.shape[0] < self.clist._data.shape[0]) + assert_true(len(clist) < len(self.clist)) + assert_true(clist._data is not self.clist._data) + + def test_compactlist_append(self): + # Maybe not necessary if `self.setUp` is always called before a + # test method, anyways create a copy just in case. + clist = self.clist.copy() + + rng = np.random.RandomState(1234) + element = rng.rand(rng.randint(10, 50), *self.clist.shape) + clist.append(element) + assert_equal(len(clist), len(self.clist)+1) + assert_equal(clist._offsets[-1], len(self.clist._data)) + assert_equal(clist._lengths[-1], len(element)) + assert_array_equal(clist._data[-len(element):], element) + + # Append with different shape. + element = rng.rand(rng.randint(10, 50), 42) + assert_raises(ValueError, clist.append, element) + + # Append to an empty CompactList. + clist = CompactList() + rng = np.random.RandomState(1234) + shape = (2, 3, 4) + element = rng.rand(rng.randint(10, 50), *shape) + clist.append(element) + + assert_equal(len(clist), 1) + assert_equal(clist._offsets[-1], 0) + assert_equal(clist._lengths[-1], len(element)) + assert_array_equal(clist._data, element) + assert_equal(clist.shape, shape) + + def test_compactlist_extend(self): + # Maybe not necessary if `self.setUp` is always called before a + # test method, anyways create a copy just in case. + clist = self.clist.copy() + + rng = np.random.RandomState(1234) + shape = self.clist.shape + new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(5)] + lengths = map(len, new_data) + clist.extend(new_data) + assert_equal(len(clist), len(self.clist)+len(new_data)) + assert_array_equal(clist._offsets[-len(new_data):], + len(self.clist._data) + np.cumsum([0] + lengths[:-1])) + + assert_equal(clist._lengths[-len(new_data):], lengths) + assert_array_equal(clist._data[-sum(lengths):], + np.concatenate(new_data, axis=0)) + + # Extend with another `CompactList` object. + clist = self.clist.copy() + new_data = CompactList(new_data) + clist.extend(new_data) + assert_equal(len(clist), len(self.clist)+len(new_data)) + assert_array_equal(clist._offsets[-len(new_data):], + len(self.clist._data) + np.cumsum([0] + lengths[:-1])) + + assert_equal(clist._lengths[-len(new_data):], lengths) + assert_array_equal(clist._data[-sum(lengths):], new_data._data) + + def test_compactlist_getitem(self): + # Get one item + for i, e in enumerate(self.clist): + assert_array_equal(self.clist[i], e) + + # Get multiple items (this will create a view). 
+ clist_view = self.clist[range(len(self.clist))] + assert_true(clist_view is not self.clist) + assert_true(clist_view._data is self.clist._data) + assert_true(clist_view._offsets is not self.clist._offsets) + assert_true(clist_view._lengths is not self.clist._lengths) + assert_array_equal(clist_view._offsets, self.clist._offsets) + assert_array_equal(clist_view._lengths, self.clist._lengths) + for e1, e2 in zip_longest(clist_view, self.clist): + assert_array_equal(e1, e2) + + # Get slice (this will create a view). + clist_view = self.clist[::2] + assert_true(clist_view is not self.clist) + assert_true(clist_view._data is self.clist._data) + assert_array_equal(clist_view._offsets, self.clist._offsets[::2]) + assert_array_equal(clist_view._lengths, self.clist._lengths[::2]) + for i, e in enumerate(clist_view): + assert_array_equal(e, self.clist[i*2]) + + +def test_save_and_load_compact_list(): + + with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + clist = CompactList() + save_compact_list(f, clist) + f.seek(0, os.SEEK_SET) + loaded_clist = load_compact_list(f) + assert_array_equal(loaded_clist._data, clist._data) + assert_array_equal(loaded_clist._offsets, clist._offsets) + assert_array_equal(loaded_clist._lengths, clist._lengths) + + with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + clist = CompactList(data) + save_compact_list(f, clist) + f.seek(0, os.SEEK_SET) + loaded_clist = load_compact_list(f) + assert_array_equal(loaded_clist._data, clist._data) + assert_array_equal(loaded_clist._offsets, clist._offsets) + assert_array_equal(loaded_clist._lengths, clist._lengths) diff --git a/nibabel/streamlines/tests/test_base_format.py b/nibabel/streamlines/tests/test_tractogram.py similarity index 52% rename from nibabel/streamlines/tests/test_base_format.py rename to nibabel/streamlines/tests/test_tractogram.py index 53d8589ab7..f106e5e08b 100644 --- a/nibabel/streamlines/tests/test_base_format.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -1,4 +1,3 @@ -import os import unittest import numpy as np import warnings @@ -6,202 +5,12 @@ from nibabel.testing import assert_arrays_equal, isiterable from nibabel.testing import suppress_warnings, clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true -from numpy.testing import assert_array_equal, assert_array_almost_equal -from nibabel.externals.six.moves import zip, zip_longest +from numpy.testing import assert_array_equal +from nibabel.externals.six.moves import zip -from .. 
import base_format -from ..base_format import CompactList -from ..base_format import TractogramItem, Tractogram, LazyTractogram -from ..base_format import UsageWarning - -DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') - - -class TestCompactList(unittest.TestCase): - - def setUp(self): - rng = np.random.RandomState(42) - self.data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - self.lengths = map(len, self.data) - self.clist = CompactList(self.data) - - def test_creating_empty_compactlist(self): - clist = CompactList() - assert_equal(len(clist), 0) - assert_equal(len(clist._offsets), 0) - assert_equal(len(clist._lengths), 0) - assert_true(clist._data is None) - assert_true(clist.shape is None) - - def test_creating_compactlist_from_list(self): - rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = map(len, data) - - clist = CompactList(data) - assert_equal(len(clist), len(data)) - assert_equal(len(clist._offsets), len(data)) - assert_equal(len(clist._lengths), len(data)) - assert_equal(clist._data.shape[0], sum(lengths)) - assert_equal(clist._data.shape[1], 3) - assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) - assert_equal(clist._lengths, lengths) - assert_equal(clist.shape, data[0].shape[1:]) - - # Empty list - clist = CompactList([]) - assert_equal(len(clist), 0) - assert_equal(len(clist._offsets), 0) - assert_equal(len(clist._lengths), 0) - assert_true(clist._data is None) - assert_true(clist.shape is None) - - def test_creating_compactlist_from_generator(self): - rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = map(len, data) - - gen = (e for e in data) - clist = CompactList(gen) - assert_equal(len(clist), len(data)) - assert_equal(len(clist._offsets), len(data)) - assert_equal(len(clist._lengths), len(data)) - assert_equal(clist._data.shape[0], sum(lengths)) - assert_equal(clist._data.shape[1], 3) - assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) - assert_equal(clist._lengths, lengths) - assert_equal(clist.shape, data[0].shape[1:]) - - # Already consumed generator - clist = CompactList(gen) - assert_equal(len(clist), 0) - assert_equal(len(clist._offsets), 0) - assert_equal(len(clist._lengths), 0) - assert_true(clist._data is None) - assert_true(clist.shape is None) - - def test_creating_compactlist_from_compact_list(self): - rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = map(len, data) - - clist = CompactList(data) - clist2 = CompactList(clist) - assert_equal(len(clist2), len(data)) - assert_equal(len(clist2._offsets), len(data)) - assert_equal(len(clist2._lengths), len(data)) - assert_equal(clist2._data.shape[0], sum(lengths)) - assert_equal(clist2._data.shape[1], 3) - assert_equal(clist2._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) - assert_equal(clist2._lengths, lengths) - assert_equal(clist2.shape, data[0].shape[1:]) - - def test_compactlist_iter(self): - for e, d in zip(self.clist, self.data): - assert_array_equal(e, d) - - def test_compactlist_copy(self): - clist = self.clist.copy() - assert_array_equal(clist._data, self.clist._data) - assert_true(clist._data is not self.clist._data) - assert_array_equal(clist._offsets, self.clist._offsets) - assert_true(clist._offsets is not self.clist._offsets) - assert_array_equal(clist._lengths, self.clist._lengths) - assert_true(clist._lengths is not self.clist._lengths) - - assert_equal(clist.shape, 
self.clist.shape) - - # When taking a copy of a `CompactList` generated by slicing. - # Only needed data should be kept. - clist = self.clist[::2].copy() - - assert_true(clist._data.shape[0] < self.clist._data.shape[0]) - assert_true(len(clist) < len(self.clist)) - assert_true(clist._data is not self.clist._data) - - def test_compactlist_append(self): - # Maybe not necessary if `self.setUp` is always called before a - # test method, anyways create a copy just in case. - clist = self.clist.copy() - - rng = np.random.RandomState(1234) - element = rng.rand(rng.randint(10, 50), *self.clist.shape) - clist.append(element) - assert_equal(len(clist), len(self.clist)+1) - assert_equal(clist._offsets[-1], len(self.clist._data)) - assert_equal(clist._lengths[-1], len(element)) - assert_array_equal(clist._data[-len(element):], element) - - # Append with different shape. - element = rng.rand(rng.randint(10, 50), 42) - assert_raises(ValueError, clist.append, element) - - # Append to an empty CompactList. - clist = CompactList() - rng = np.random.RandomState(1234) - shape = (2, 3, 4) - element = rng.rand(rng.randint(10, 50), *shape) - clist.append(element) - - assert_equal(len(clist), 1) - assert_equal(clist._offsets[-1], 0) - assert_equal(clist._lengths[-1], len(element)) - assert_array_equal(clist._data, element) - assert_equal(clist.shape, shape) - - def test_compactlist_extend(self): - # Maybe not necessary if `self.setUp` is always called before a - # test method, anyways create a copy just in case. - clist = self.clist.copy() - - rng = np.random.RandomState(1234) - shape = self.clist.shape - new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(5)] - lengths = map(len, new_data) - clist.extend(new_data) - assert_equal(len(clist), len(self.clist)+len(new_data)) - assert_array_equal(clist._offsets[-len(new_data):], - len(self.clist._data) + np.cumsum([0] + lengths[:-1])) - - assert_equal(clist._lengths[-len(new_data):], lengths) - assert_array_equal(clist._data[-sum(lengths):], - np.concatenate(new_data, axis=0)) - - # Extend with another `CompactList` object. - clist = self.clist.copy() - new_data = CompactList(new_data) - clist.extend(new_data) - assert_equal(len(clist), len(self.clist)+len(new_data)) - assert_array_equal(clist._offsets[-len(new_data):], - len(self.clist._data) + np.cumsum([0] + lengths[:-1])) - - assert_equal(clist._lengths[-len(new_data):], lengths) - assert_array_equal(clist._data[-sum(lengths):], new_data._data) - - def test_compactlist_getitem(self): - # Get one item - for i, e in enumerate(self.clist): - assert_array_equal(self.clist[i], e) - - # Get multiple items (this will create a view). - clist_view = self.clist[range(len(self.clist))] - assert_true(clist_view is not self.clist) - assert_true(clist_view._data is self.clist._data) - assert_true(clist_view._offsets is not self.clist._offsets) - assert_true(clist_view._lengths is not self.clist._lengths) - assert_array_equal(clist_view._offsets, self.clist._offsets) - assert_array_equal(clist_view._lengths, self.clist._lengths) - for e1, e2 in zip_longest(clist_view, self.clist): - assert_array_equal(e1, e2) - - # Get slice (this will create a view). - clist_view = self.clist[::2] - assert_true(clist_view is not self.clist) - assert_true(clist_view._data is self.clist._data) - assert_array_equal(clist_view._offsets, self.clist._offsets[::2]) - assert_array_equal(clist_view._lengths, self.clist._lengths[::2]) - for i, e in enumerate(clist_view): - assert_array_equal(e, self.clist[i*2]) +from .. 
import tractogram as module_tractogram +from ..tractogram import UsageWarning +from ..tractogram import TractogramItem, Tractogram, LazyTractogram class TestTractogramItem(unittest.TestCase): @@ -247,7 +56,7 @@ def setUp(self): [0, 0, 1], [1, 0, 0]], dtype="f4") - self.nb_tractogram = len(self.streamlines) + self.nb_streamlines = len(self.streamlines) def test_tractogram_creation(self): # Create an empty tractogram. @@ -379,10 +188,6 @@ def test_tractogram_add_new_data(self): class TestLazyTractogram(unittest.TestCase): def setUp(self): - self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") - self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") - self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") - self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), np.arange(2*3, dtype="f4").reshape((2, 3)), np.arange(5*3, dtype="f4").reshape((5, 3))] @@ -391,71 +196,81 @@ def setUp(self): np.array([(0, 1, 0)]*2, dtype="f4"), np.array([(0, 0, 1)]*5, dtype="f4")] - self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), - np.array([2.11, 2.22], dtype="f4"), - np.array([3.11, 3.22], dtype="f4")] + self.mean_curvature = np.array([1.11, 2.11, 3.11], dtype="f4") + self.mean_color = np.array([[0, 1, 0], + [0, 0, 1], + [1, 0, 0]], dtype="f4") self.nb_streamlines = len(self.streamlines) - self.nb_scalars_per_point = self.colors[0].shape[1] - self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) def test_lazy_tractogram_creation(self): # To create tractogram from arrays use `Tractogram`. assert_raises(TypeError, LazyTractogram, self.streamlines) - # Points, scalars and properties + # Streamlines and other data as generators streamlines = (x for x in self.streamlines) - scalars = (x for x in self.colors) - properties = (x for x in self.mean_curvature_torsion) + data_per_point = {"colors": (x for x in self.colors)} + data_per_streamline = {'mean_curv': (x for x in self.mean_curvature), + 'mean_color': (x for x in self.mean_color)} - # Creating LazyTractogram from generators is not allowed as - # generators get exhausted and are not reusable unline coroutines. + # Creating LazyTractogram with generators is not allowed as + # generators get exhausted and are not reusable unlike coroutines. assert_raises(TypeError, LazyTractogram, streamlines) - assert_raises(TypeError, LazyTractogram, self.streamlines, scalars) - assert_raises(TypeError, LazyTractogram, properties_func=properties) + assert_raises(TypeError, LazyTractogram, + data_per_streamline=data_per_streamline) + assert_raises(TypeError, LazyTractogram, self.streamlines, + data_per_point=data_per_point) # Empty `LazyTractogram` tractogram = LazyTractogram() - with suppress_warnings(): - assert_equal(len(tractogram), 0) + assert_true(isiterable(tractogram)) + assert_equal(len(tractogram), 0) assert_arrays_equal(tractogram.streamlines, []) - assert_arrays_equal(tractogram.scalars, []) - assert_arrays_equal(tractogram.properties, []) - - # Check if we can iterate through the tractogram. 
- for streamline in tractogram: - pass + assert_equal(tractogram.data_per_point, {}) + assert_equal(tractogram.data_per_streamline, {}) - # Points, scalars and properties + # Create tractogram with streamlines and other data streamlines = lambda: (x for x in self.streamlines) - scalars = lambda: (x for x in self.colors) - properties = lambda: (x for x in self.mean_curvature_torsion) + data_per_point = {"colors": lambda: (x for x in self.colors)} + data_per_streamline = {'mean_curv': lambda: (x for x in self.mean_curvature), + 'mean_color': lambda: (x for x in self.mean_color)} - tractogram = LazyTractogram(streamlines, scalars, properties) - with suppress_warnings(): - assert_equal(len(tractogram), self.nb_streamlines) + tractogram = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + + assert_true(isiterable(tractogram)) + assert_equal(len(tractogram), self.nb_streamlines) # Coroutines get re-called and creates new iterators. - assert_arrays_equal(tractogram.streamlines, self.streamlines) - assert_arrays_equal(tractogram.scalars, self.colors) - assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) - assert_arrays_equal(tractogram.streamlines, self.streamlines) - assert_arrays_equal(tractogram.scalars, self.colors) - assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) + for i in range(2): + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_arrays_equal(tractogram.data_per_streamline['mean_curv'], + self.mean_curvature) + assert_arrays_equal(tractogram.data_per_streamline['mean_color'], + self.mean_color) + assert_arrays_equal(tractogram.data_per_point['colors'], + self.colors) # Create `LazyTractogram` from a coroutine yielding 3-tuples - data = lambda: (x for x in zip(self.streamlines, self.colors, self.mean_curvature_torsion)) - - tractogram = LazyTractogram.create_from_data(data) - with suppress_warnings(): - assert_equal(len(tractogram), self.nb_streamlines) + def _data_gen(): + for d in zip(self.streamlines, self.colors, + self.mean_curvature, self.mean_color): + data_for_points = {'colors': d[1]} + data_for_streamline = {'mean_curv': d[2], + 'mean_color': d[3]} + yield TractogramItem(d[0], data_for_streamline, data_for_points) + + tractogram = LazyTractogram.create_from(_data_gen) + assert_true(isiterable(tractogram)) + assert_equal(len(tractogram), self.nb_streamlines) assert_arrays_equal(tractogram.streamlines, self.streamlines) - assert_arrays_equal(tractogram.scalars, self.colors) - assert_arrays_equal(tractogram.properties, self.mean_curvature_torsion) - - # Check if we can iterate through the tractogram. 
-        for streamline in tractogram:
-            pass
+        assert_arrays_equal(tractogram.data_per_streamline['mean_curv'],
+                            self.mean_curvature)
+        assert_arrays_equal(tractogram.data_per_streamline['mean_color'],
+                            self.mean_color)
+        assert_arrays_equal(tractogram.data_per_point['colors'],
+                            self.colors)
 
     def test_lazy_tractogram_indexing(self):
         streamlines = lambda: (x for x in self.streamlines)
@@ -473,7 +288,8 @@ def getitem_without_properties(idx):
             return list(zip(self.streamlines[idx], self.colors[idx]))
 
-        tractogram = LazyTractogram(streamlines, scalars, properties, getitem_without_properties)
+        tractogram = LazyTractogram(streamlines, scalars, properties,
+                                    getitem_without_properties)
         streamlines, scalars = tractogram[0]
         assert_array_equal(streamlines, self.streamlines[0])
         assert_array_equal(scalars, self.colors[0])
@@ -491,7 +307,7 @@ def test_lazy_tractogram_len(self):
         scalars = lambda: (x for x in self.colors)
         properties = lambda: (x for x in self.mean_curvature_torsion)
 
-        with clear_and_catch_warnings(record=True, modules=[base_format]) as w:
+        with clear_and_catch_warnings(record=True, modules=[module_tractogram]) as w:
             warnings.simplefilter("always")  # Always trigger warnings.
 
             # Calling `len` will create new generators each time.
@@ -510,19 +326,20 @@ def test_lazy_tractogram_len(self):
             assert_equal(len(tractogram), self.nb_streamlines)
             assert_equal(len(w), 2)
 
-        with clear_and_catch_warnings(record=True, modules=[base_format]) as w:
+        with clear_and_catch_warnings(record=True, modules=[module_tractogram]) as w:
             # Once we iterated through the tractogram, we know the length.
             tractogram = LazyTractogram(streamlines, scalars, properties)
             assert_true(tractogram.header.nb_streamlines is None)
             for streamline in tractogram:
                 pass
-            assert_equal(tractogram.header.nb_streamlines, len(self.streamlines))
+            assert_equal(tractogram.header.nb_streamlines,
+                         len(self.streamlines))
             # This should *not* produce a warning.
             assert_equal(len(tractogram), len(self.streamlines))
             assert_equal(len(w), 0)
 
-        with clear_and_catch_warnings(record=True, modules=[base_format]) as w:
+        with clear_and_catch_warnings(record=True, modules=[module_tractogram]) as w:
             # It first checks if the number of streamlines is in the header.
             tractogram = LazyTractogram(streamlines, scalars, properties)
             tractogram.header.nb_streamlines = 1234
diff --git a/nibabel/streamlines/tests/test_utils.py b/nibabel/streamlines/tests/test_utils.py
index b5c67e0770..6c3bf096a6 100644
--- a/nibabel/streamlines/tests/test_utils.py
+++ b/nibabel/streamlines/tests/test_utils.py
@@ -1,6 +1,4 @@
 import os
-import unittest
-import tempfile
 
 import numpy as np
 import nibabel as nib
@@ -8,9 +6,7 @@
 from numpy.testing import assert_array_equal
 from nose.tools import assert_equal, assert_raises, assert_true
 
-from ..base_format import CompactList
 from ..utils import pop, get_affine_from_reference
-from ..utils import save_compact_list, load_compact_list
 
 
 def test_peek():
@@ -36,26 +32,3 @@ def test_get_affine_from_reference():
 
     # Get affine from a `SpatialImage` using its filename.
     assert_array_equal(get_affine_from_reference(filename), affine)
-
-
-def test_save_and_load_compact_list():
-
-    with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f:
-        clist = CompactList()
-        save_compact_list(f, clist)
-        f.seek(0, os.SEEK_SET)
-        loaded_clist = load_compact_list(f)
-        assert_array_equal(loaded_clist._data, clist._data)
-        assert_array_equal(loaded_clist._offsets, clist._offsets)
-        assert_array_equal(loaded_clist._lengths, clist._lengths)
-
-    with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f:
-        rng = np.random.RandomState(42)
-        data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)]
-        clist = CompactList(data)
-        save_compact_list(f, clist)
-        f.seek(0, os.SEEK_SET)
-        loaded_clist = load_compact_list(f)
-        assert_array_equal(loaded_clist._data, clist._data)
-        assert_array_equal(loaded_clist._offsets, clist._offsets)
-        assert_array_equal(loaded_clist._lengths, clist._lengths)
diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
new file mode 100644
index 0000000000..c7e0593065
--- /dev/null
+++ b/nibabel/streamlines/tractogram.py
@@ -0,0 +1,390 @@
+import numpy as np
+from warnings import warn
+
+from nibabel.affines import apply_affine
+
+from .compact_list import CompactList
+
+
+class UsageWarning(Warning):
+    pass
+
+
+class TractogramItem(object):
+    """ Class containing information about one streamline.
+
+    ``TractogramItem`` objects have three main properties: `streamline`,
+    `data_for_streamline`, and `data_for_points`.
+
+    Parameters
+    ----------
+    streamline : ndarray of shape (N, 3)
+        Points of this streamline represented as an ndarray of shape (N, 3)
+        where N is the number of points.
+
+    data_for_streamline : dict
+        Dictionary containing the data associated with this streamline. Each
+        key is a data name and its value is an ndarray of shape (P,).
+
+    data_for_points : dict
+        Dictionary containing the data associated with each point of this
+        streamline. Each key is a data name and its value is an ndarray of
+        shape (N, M).
+    """
+    def __init__(self, streamline, data_for_streamline, data_for_points):
+        self.streamline = np.asarray(streamline)
+        self.data_for_streamline = data_for_streamline
+        self.data_for_points = data_for_points
+
+    def __iter__(self):
+        return iter(self.streamline)
+
+    def __len__(self):
+        return len(self.streamline)
+
+
+class Tractogram(object):
+    """ Class containing information about streamlines.
+
+    ``Tractogram`` objects have three main properties: ``streamlines``,
+    ``data_per_streamline`` and ``data_per_point``.
+
+    Parameters
+    ----------
+    streamlines : list of ndarray of shape (Nt, 3)
+        Sequence of T streamlines. One streamline is an ndarray of shape
+        (Nt, 3) where Nt is the number of points of streamline t.
+
+    data_per_streamline : dictionary of list of ndarray of shape (P,)
+        Sequence of T ndarrays of shape (P,) where T is the number of
+        streamlines defined by ``streamlines``, P is the number of properties
+        associated with each streamline.
+
+    data_per_point : dictionary of list of ndarray of shape (Nt, M)
+        Sequence of T ndarrays of shape (Nt, M) where T is the number of
+        streamlines defined by ``streamlines``, Nt is the number of points
+        for a particular streamline t and M is the number of scalars
+        associated with each point (excluding the three coordinates).
+ + """ + def __init__(self, streamlines=None, + data_per_streamline=None, + data_per_point=None): + + self.streamlines = streamlines + self.data_per_streamline = data_per_streamline + self.data_per_point = data_per_point + + @property + def streamlines(self): + return self._streamlines + + @streamlines.setter + def streamlines(self, value): + self._streamlines = CompactList(value) + + @property + def data_per_streamline(self): + return self._data_per_streamline + + @data_per_streamline.setter + def data_per_streamline(self, value): + if value is None: + value = {} + + self._data_per_streamline = {} + for k, v in value.items(): + self._data_per_streamline[k] = np.asarray(v) + + @property + def data_per_point(self): + return self._data_per_point + + @data_per_point.setter + def data_per_point(self, value): + if value is None: + value = {} + + self._data_per_point = {} + for k, v in value.items(): + self._data_per_point[k] = CompactList(v) + + def __iter__(self): + for i in range(len(self.streamlines)): + yield self[i] + + def __getitem__(self, idx): + pts = self.streamlines[idx] + + new_data_per_streamline = {} + for key in self.data_per_streamline: + new_data_per_streamline[key] = self.data_per_streamline[key][idx] + + new_data_per_point = {} + for key in self.data_per_point: + new_data_per_point[key] = self.data_per_point[key][idx] + + if type(idx) is slice: + return Tractogram(pts, new_data_per_streamline, new_data_per_point) + + return TractogramItem(pts, new_data_per_streamline, new_data_per_point) + + def __len__(self): + return len(self.streamlines) + + def copy(self): + """ Returns a copy of this `Tractogram` object. """ + new_data_per_streamline = {} + for key in self.data_per_streamline: + new_data_per_streamline[key] = self.data_per_streamline[key].copy() + + new_data_per_point = {} + for key in self.data_per_point: + new_data_per_point[key] = self.data_per_point[key].copy() + + tractogram = Tractogram(self.streamlines.copy(), + new_data_per_streamline, + new_data_per_point) + return tractogram + + def apply_affine(self, affine): + """ Applies an affine transformation on the points of each streamline. + + This is performed in-place. + + Parameters + ---------- + affine : 2D array (4,4) + Transformation that will be applied on each streamline. + """ + if len(self.streamlines) == 0: + return + + BUFFER_SIZE = 10000000 # About 128 Mb since pts shape is 3. + for i in range(0, len(self.streamlines._data), BUFFER_SIZE): + pts = self.streamlines._data[i:i+BUFFER_SIZE] + self.streamlines._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts) + + +import collections +class LazyTractogram(Tractogram): + ''' Class containing information about streamlines. + + Tractogram objects have four main properties: ``header``, ``streamlines``, + ``scalars`` and ``properties``. Tractogram objects are iterable and + produce tuple of ``streamlines``, ``scalars`` and ``properties`` for each + streamline. + + Parameters + ---------- + streamlines_func : coroutine ouputting (Nt,3) array-like (optional) + Function yielding streamlines. One streamline is + an ndarray of shape (Nt,3) where Nt is the number of points of + streamline t. + + scalars_func : coroutine ouputting (Nt,M) array-like (optional) + Function yielding scalars for a particular streamline t. The scalars + are represented as an ndarray of shape (Nt,M) where Nt is the number + of points of that streamline t and M is the number of scalars + associated to each point (excluding the three coordinates). 
+
+    data_per_point : dict of coroutines outputting (Nt,M) array-like (optional)
+        Dictionary where each value is a function that, each time it is
+        called, returns a new generator yielding, for every streamline t,
+        the (Nt,M) ndarray of scalars associated with its points (excluding
+        the three coordinates).
+
+    Notes
+    -----
+    If provided, the coroutines in ``data_per_streamline`` and
+    ``data_per_point`` must yield the same number of values as
+    ``streamlines``.
+    '''
+
+    class LazyDict(collections.MutableMapping):
+        """ Internal dictionary with lazy evaluations. """
+
+        def __init__(self, *args, **kwargs):
+            self.store = dict()
+            self.update(dict(*args, **kwargs))  # Use update to set keys.
+
+        def __getitem__(self, key):
+            return self.store[key]()
+
+        def __setitem__(self, key, value):
+            if value is not None and not callable(value):
+                raise TypeError("`value` must be a coroutine or None.")
+
+            self.store[key] = value
+
+        def __delitem__(self, key):
+            del self.store[key]
+
+        def __iter__(self):
+            return iter(self.store)
+
+        def __len__(self):
+            return len(self.store)
+
+    def __init__(self, streamlines=None, data_per_streamline=None,
+                 data_per_point=None):
+        super(LazyTractogram, self).__init__(streamlines, data_per_streamline,
+                                             data_per_point)
+        self._nb_streamlines = None
+        self._data = None
+        self._getitem = None
+        self._affine_to_apply = np.eye(4)
+
+    @classmethod
+    def create_from(cls, data_func):
+        ''' Creates a `LazyTractogram` from a coroutine yielding
+        `TractogramItem` objects.
+
+        Parameters
+        ----------
+        data_func : coroutine yielding `TractogramItem` objects
+            Function that, each time it is called, returns a new generator
+            yielding the `TractogramItem` objects that should be part of
+            this LazyTractogram.
+
+        '''
+        if not callable(data_func):
+            raise TypeError("`data_func` must be a coroutine.")
+
+        lazy_tractogram = cls()
+        lazy_tractogram._data = data_func
+
+        # Set data_per_streamline using data_func
+        def _gen(key):
+            return lambda: (t.data_for_streamline[key] for t in data_func())
+
+        data_per_streamline_keys = next(data_func()).data_for_streamline.keys()
+        for k in data_per_streamline_keys:
+            lazy_tractogram._data_per_streamline[k] = _gen(k)
+
+        # Set data_per_point using data_func
+        def _gen(key):
+            return lambda: (t.data_for_points[key] for t in data_func())
+
+        data_per_point_keys = next(data_func()).data_for_points.keys()
+        for k in data_per_point_keys:
+            lazy_tractogram._data_per_point[k] = _gen(k)
+
+        return lazy_tractogram
+
+    @property
+    def streamlines(self):
+        streamlines_gen = iter([])
+        if self._streamlines is not None:
+            streamlines_gen = self._streamlines()
+        elif self._data is not None:
+            streamlines_gen = (t.streamline for t in self._data())
+
+        # Check if we need to apply an affine.
+        if not np.all(self._affine_to_apply == np.eye(4)):
+            def _apply_affine():
+                for s in streamlines_gen:
+                    yield apply_affine(self._affine_to_apply, s)
+
+            streamlines_gen = _apply_affine()
+
+        return streamlines_gen
+
+    @streamlines.setter
+    def streamlines(self, value):
+        if value is not None and not callable(value):
+            raise TypeError("`streamlines` must be a coroutine.")
+
+        self._streamlines = value
+
+    @property
+    def data_per_streamline(self):
+        return self._data_per_streamline
+
+    @data_per_streamline.setter
+    def data_per_streamline(self, value):
+        if value is None:
+            value = {}
+
+        self._data_per_streamline = LazyTractogram.LazyDict(value)
+
+    @property
+    def data_per_point(self):
+        return self._data_per_point
+
+    @data_per_point.setter
+    def data_per_point(self, value):
+        if value is None:
+            value = {}
+
+        self._data_per_point = LazyTractogram.LazyDict(value)
+
+    @property
+    def data(self):
+        if self._data is not None:
+            return self._data()
+
+        def _gen_data():
+            data_per_streamline_generators = {}
+            for k, v in self.data_per_streamline.items():
+                data_per_streamline_generators[k] = iter(v)
+
+            data_per_point_generators = {}
+            for k, v in self.data_per_point.items():
+                data_per_point_generators[k] = iter(v)
+
+            for s in self.streamlines:
+                data_for_streamline = {}
+                for k, v in data_per_streamline_generators.items():
+                    data_for_streamline[k] = next(v)
+
+                data_for_points = {}
+                for k, v in data_per_point_generators.items():
+                    data_for_points[k] = next(v)
+
+                yield TractogramItem(s, data_for_streamline, data_for_points)
+
+        return _gen_data()
+
+    def __getitem__(self, idx):
+        if self._getitem is None:
+            raise AttributeError('`LazyTractogram` does not support indexing.')
+
+        return self._getitem(idx)
+
+    def __iter__(self):
+        i = 0
+        for i, tractogram_item in enumerate(self.data, start=1):
+            yield tractogram_item
+
+        # Remember how many streamlines there are in this tractogram.
+        self._nb_streamlines = i
+
+    def __len__(self):
+        # Check if we know how many streamlines there are.
+        if self._nb_streamlines is None:
+            warn("Number of streamlines will be determined manually by looping"
+                 " through the streamlines. If you know the actual number of"
+                 " streamlines, you might want to set it beforehand via"
+                 " `self.header.nb_streamlines`."
+                 " Note this will consume any generators used to create this"
+                 " `LazyTractogram` object.", UsageWarning)
+            # Count the number of streamlines.
+            self._nb_streamlines = sum(1 for _ in self.streamlines)
+
+        return self._nb_streamlines
+
+    def copy(self):
+        """ Returns a copy of this `LazyTractogram` object. """
+        # Pass the underlying dicts of coroutines (`.store`), not the
+        # `LazyDict` objects themselves, so the coroutines do not get
+        # evaluated while copying.
+        tractogram = LazyTractogram(self._streamlines,
+                                    self._data_per_streamline.store,
+                                    self._data_per_point.store)
+        tractogram._nb_streamlines = self._nb_streamlines
+        tractogram._data = self._data
+        return tractogram
+
+    def apply_affine(self, affine):
+        """ Applies an affine transformation on the streamlines.
+
+        Parameters
+        ----------
+        affine : 2D array (4,4)
+            Transformation that will be applied on each streamline.
+ """ + self._affine_to_apply = np.dot(affine, self._affine_to_apply) diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py new file mode 100644 index 0000000000..2ec5b69a68 --- /dev/null +++ b/nibabel/streamlines/tractogram_file.py @@ -0,0 +1,103 @@ +from abc import ABCMeta, abstractmethod, abstractproperty + +from .header import TractogramHeader + + +class abstractclassmethod(classmethod): + __isabstractmethod__ = True + + def __init__(self, callable): + callable.__isabstractmethod__ = True + super(abstractclassmethod, self).__init__(callable) + + +class TractogramFile(object): + ''' Convenience class to encapsulate tractogram file format. ''' + __metaclass__ = ABCMeta + + def __init__(self, tractogram, header=None): + self._tractogram = tractogram + self._header = TractogramHeader() if header is None else header + + @property + def tractogram(self): + return self._tractogram + + @property + def streamlines(self): + return self.tractogram.streamlines + + @property + def header(self): + return self._header + + def get_tractogram(self): + return self.tractogram + + def get_streamlines(self): + return self.streamlines + + def get_header(self): + return self.header + + @classmethod + def get_magic_number(cls): + ''' Returns streamlines file's magic number. ''' + raise NotImplementedError() + + @classmethod + def support_data_per_point(cls): + ''' Tells if this tractogram format supports saving data per point. ''' + raise NotImplementedError() + + @classmethod + def support_data_per_streamline(cls): + ''' Tells if this tractogram format supports saving data per streamline. ''' + raise NotImplementedError() + + @classmethod + def is_correct_format(cls, fileobj): + ''' Checks if the file has the right streamlines file format. + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header). + Returns + ------- + is_correct_format : boolean + Returns True if `fileobj` is in the right streamlines file format. + ''' + raise NotImplementedError() + + @abstractclassmethod + def load(cls, fileobj, lazy_load=True): + ''' Loads streamlines from a file-like object. + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + pointing to a streamlines file (and ready to read from the + beginning of the header). + lazy_load : boolean (optional) + Load streamlines in a lazy manner i.e. they will not be kept + in memory. For postprocessing speed, turn off this option. + Returns + ------- + tractogram_file : ``TractogramFile`` object + Returns an object containing tractogram data and header + information. + ''' + raise NotImplementedError() + + @abstractmethod + def save(self, fileobj): + ''' Saves streamlines to a file-like object. + Parameters + ---------- + fileobj : string or file-like object + If string, a filename; otherwise an open file-like object + opened and ready to write. 
+        '''
+        raise NotImplementedError()
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index b93a65f5d9..6d1fc66397 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -3,9 +3,10 @@
 # Documentation available here:
 # http://www.trackvis.org/docs/?subsect=fileformat
 
-import struct
 import os
+import struct
 import warnings
+import itertools
 
 import numpy as np
 import nibabel as nib
@@ -13,12 +14,13 @@
 from nibabel.openers import Opener
 from nibabel.volumeutils import (native_code, swapped_code)
 
-from nibabel.streamlines.base_format import TractogramFile
-from nibabel.streamlines.base_format import DataError, HeaderError, HeaderWarning
-from nibabel.streamlines.base_format import Tractogram, LazyTractogram
-from nibabel.streamlines.header import Field
+from .compact_list import CompactList
+from .tractogram_file import TractogramFile
+from .base_format import DataError, HeaderError, HeaderWarning
+from .tractogram import Tractogram, LazyTractogram
+from .header import Field
 
-from nibabel.streamlines.utils import get_affine_from_reference
+from .utils import get_affine_from_reference
 
 # Definition of trackvis header structure.
 # See http://www.trackvis.org/docs/?subsect=fileformat
@@ -298,9 +300,6 @@ def write(self, tractogram):
 
         self.file.write(self.header.tostring())
 
-import itertools
-from nibabel.streamlines.base_format import CompactList
-
 def create_compactlist_from_generator(gen):
     BUFFER_SIZE = 10000000  # About 128 Mb if item shape is 3.
 
diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py
index 7fbfab5106..7bbbe1ef8d 100644
--- a/nibabel/streamlines/utils.py
+++ b/nibabel/streamlines/utils.py
@@ -33,22 +33,3 @@ def pop(iterable):
     "Returns the next item from the iterable else None"
     value = list(itertools.islice(iterable, 1))
     return value[0] if len(value) > 0 else None
-
-
-def save_compact_list(filename, compact_list):
-    """ Saves a `CompactList` object to a .npz file. """
-    np.savez(filename,
-             data=compact_list._data,
-             offsets=compact_list._offsets,
-             lengths=compact_list._lengths)
-
-
-def load_compact_list(filename):
-    """ Loads a `CompactList` object from a .npz file. """
-    from .base_format import CompactList
-    content = np.load(filename)
-    compact_list = CompactList()
-    compact_list._data = content["data"]
-    compact_list._offsets = content["offsets"].tolist()
-    compact_list._lengths = content["lengths"].tolist()
-    return compact_list

From 100cb849a9bd306f845cccd20c14a99ba6907d1e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Tue, 17 Nov 2015 17:33:55 -0500
Subject: [PATCH 027/135] Save scalar and property names when using the TRK
 file format

---
 nibabel/streamlines/tractogram.py |   2 +-
 nibabel/streamlines/trk.py        | 119 ++++++++++++++++++++++--------
 2 files changed, 89 insertions(+), 32 deletions(-)

diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
index c7e0593065..f7907ca888 100644
--- a/nibabel/streamlines/tractogram.py
+++ b/nibabel/streamlines/tractogram.py
@@ -1,4 +1,5 @@
 import numpy as np
+import collections
 from warnings import warn
 
 from nibabel.affines import apply_affine
 
 from .compact_list import CompactList
@@ -160,7 +161,6 @@ def apply_affine(self, affine):
             self.streamlines._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts)
 
 
-import collections
 class LazyTractogram(Tractogram):
     ''' Class containing information about streamlines.
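[Usage sketch, not part of any patch in this series: with the lazy API above, a
`LazyTractogram` is built from *callables* that restart iteration on every call.
This assumes the module layout introduced here (`nibabel.streamlines.tractogram`);
the data names `mean_curv` and `colors` are made up for illustration.]

    import numpy as np

    from nibabel.streamlines.tractogram import TractogramItem, LazyTractogram

    streamlines = [np.arange(n * 3, dtype="f4").reshape((n, 3)) for n in (1, 2, 5)]

    def items():
        # Called anew for every traversal, so the tractogram stays reusable
        # (a bare generator would be exhausted after a single pass).
        for pts in streamlines:
            data_for_streamline = {"mean_curv": np.array([1.11], dtype="f4")}
            data_for_points = {"colors": np.ones_like(pts)}
            yield TractogramItem(pts, data_for_streamline, data_for_points)

    tractogram = LazyTractogram.create_from(items)
    for item in tractogram:  # streamlines and data are produced lazily
        print(len(item.streamline), sorted(item.data_for_points))

[Passing a function rather than a generator is what lets `create_from` peek at
the first item to discover the data keys and still iterate the whole tractogram
afterwards.]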
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 6d1fc66397..b8c325aca0 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -17,11 +17,14 @@ from .compact_list import CompactList from .tractogram_file import TractogramFile from .base_format import DataError, HeaderError, HeaderWarning -from .tractogram import Tractogram, LazyTractogram +from .tractogram import TractogramItem, Tractogram, LazyTractogram from .header import Field from .utils import get_affine_from_reference +MAX_NB_NAMED_SCALARS_PER_POINT = 10 +MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE = 10 + # Definition of trackvis header structure. # See http://www.trackvis.org/docs/?subsect=fileformat # See http://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html @@ -30,9 +33,9 @@ (Field.VOXEL_SIZES, 'f4', 3), (Field.ORIGIN, 'f4', 3), (Field.NB_SCALARS_PER_POINT, 'h'), - ('scalar_name', 'S20', 10), + ('scalar_name', 'S20', MAX_NB_NAMED_SCALARS_PER_POINT), (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), - ('property_name', 'S20', 10), + ('property_name', 'S20', MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE), ('reserved', 'S508'), (Field.VOXEL_ORDER, 'S4'), ('pad2', 'S4'), @@ -58,9 +61,9 @@ (Field.VOXEL_SIZES, 'f4', 3), (Field.ORIGIN, 'f4', 3), (Field.NB_SCALARS_PER_POINT, 'h'), - ('scalar_name', 'S20', 10), + ('scalar_name', 'S20', MAX_NB_NAMED_SCALARS_PER_POINT), (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), - ('property_name', 'S20', 10), + ('property_name', 'S20', MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE), (Field.VOXEL_TO_RASMM, 'f4', (4, 4)), # new field for version 2 ('reserved', 'S444'), (Field.VOXEL_ORDER, 'S4'), @@ -264,9 +267,9 @@ def write(self, tractogram): raise DataError("Missing scalars for some points!") points = np.asarray(t.streamline, dtype=f4_dtype) - keys = sorted(t.data_for_points.keys()) + keys = sorted(t.data_for_points.keys())[:MAX_NB_NAMED_SCALARS_PER_POINT] scalars = np.asarray([t.data_for_points[k] for k in keys], dtype=f4_dtype).reshape((len(points), -1)) - keys = sorted(t.data_for_streamline.keys()) + keys = sorted(t.data_for_streamline.keys())[:MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE] properties = np.asarray([t.data_for_streamline[k] for k in keys], dtype=f4_dtype).flatten() data = struct.pack(i4_dtype.str[:-1], len(points)) @@ -524,36 +527,59 @@ def _apply_transform(trk_reader): trk_reader yield pts, scals, props - data = lambda: _apply_transform(trk_reader) - tractogram = LazyTractogram.create_from_data(data) + def _read(): + for pts, scals, props in trk_reader: + # TODO + data_for_streamline = {} + data_for_points = {} + yield TractogramItem(pts, data_for_streamline, data_for_points) - # Overwrite scalars and properties if there is none - if trk_reader.header[Field.NB_SCALARS_PER_POINT] == 0: - tractogram.scalars = lambda: [] - if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] == 0: - tractogram.properties = lambda: [] + tractogram = LazyTractogram.create_from(_read) else: streamlines, scalars, properties = create_compactlist_from_generator(trk_reader) tractogram = Tractogram(streamlines) if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0: - if len(trk_reader.header['scalar_name'][0]) > 0: - for i in range(trk_reader.header[Field.NB_SCALARS_PER_POINT]): - clist = CompactList() - clist._data = scalars._data[:, i] - clist._offsets = scalars._offsets - clist._lengths = scalars._lengths - tractogram.data_per_point[trk_reader.header['scalar_name'][i]] = clist - else: - tractogram.data_per_point['scalars'] = scalars + cpt = 0 + for scalar_name in 
trk_reader.header['scalar_name']:
+                    if len(scalar_name) == 0:
+                        continue
+
+                    nb_scalars = np.fromstring(scalar_name[-1], np.int8)
+
+                    clist = CompactList()
+                    clist._data = scalars._data[:, cpt:cpt+nb_scalars]
+                    clist._offsets = scalars._offsets
+                    clist._lengths = scalars._lengths
+
+                    scalar_name = scalar_name.split('\x00')[0]
+                    tractogram.data_per_point[scalar_name] = clist
+                    cpt += nb_scalars
+
+                if cpt < trk_reader.header[Field.NB_SCALARS_PER_POINT]:
+                    #tractogram.data_per_point['scalars'] = scalars
+                    clist = CompactList()
+                    clist._data = scalars._data[:, cpt:]
+                    clist._offsets = scalars._offsets
+                    clist._lengths = scalars._lengths
+                    tractogram.data_per_point['scalars'] = clist
 
             if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0:
-                if len(trk_reader.header['property_name'][0]) > 0:
-                    for i in range(trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]):
-                        tractogram.data_per_streamline[trk_reader.header['property_name'][i]] = properties[:, i]
-                else:
-                    tractogram.data_per_streamline['properties'] = properties
+                cpt = 0
+                for property_name in trk_reader.header['property_name']:
+                    if len(property_name) == 0:
+                        continue
+
+                    nb_properties = np.fromstring(property_name[-1], np.int8)
+                    property_name = property_name.split('\x00')[0]
+                    tractogram.data_per_streamline[property_name] = properties[:, cpt:cpt+nb_properties]
+                    cpt += nb_properties
+
+                if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]:
+                    #tractogram.data_per_streamline['properties'] = properties
+                    nb_properties = np.fromstring(property_name[-1], np.int8)
+                    tractogram.data_per_streamline['properties'] = properties[:, cpt:]
 
         # Bring tractogram to RAS+ and mm space
         tractogram.apply_affine(affine)
@@ -578,9 +604,40 @@ def save(self, fileobj):
             pointing to TRK file (and ready to read from the beginning
             of the TRK header data).
         '''
-        # Update header using the tractogram.
-        self.header.nb_scalars_per_point = sum(map(lambda e: len(e[0]), self.tractogram.data_per_point.values()))
-        self.header.nb_properties_per_streamline = sum(map(lambda e: len(e[0]), self.tractogram.data_per_streamline.values()))
+        # Compute how many properties per streamline the tractogram has.
+        self.header.nb_properties_per_streamline = 0
+        self.header.extra['property_name'] = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, dtype='S20')
+        data_for_streamline = self.tractogram[0].data_for_streamline
+        for i, k in enumerate(sorted(data_for_streamline.keys())):
+            if i >= MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE:
+                warnings.warn(("Can only store {0} named properties: '{1}' will be omitted.".format(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, k)), HeaderWarning)
+
+            if len(k) > 19:
+                warnings.warn(("Property name '{0}' has been truncated to {1}.".format(k, k[:19])), HeaderWarning)
+
+            v = data_for_streamline[k]
+            self.header.nb_properties_per_streamline += v.shape[0]
+
+            property_name = k[:19].ljust(19, '\x00') + np.array(v.shape[0], dtype=np.int8).tostring()
+            self.header.extra['property_name'][i] = property_name
+
+        # Compute how many scalars per point the tractogram has.
+        self.header.nb_scalars_per_point = 0
+        self.header.extra['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20')
+        data_for_points = self.tractogram[0].data_for_points
+        for i, k in enumerate(sorted(data_for_points.keys())):
+            if i >= MAX_NB_NAMED_SCALARS_PER_POINT:
+                warnings.warn(("Can only store {0} named scalars: '{1}' will be omitted.".format(MAX_NB_NAMED_SCALARS_PER_POINT, k)), HeaderWarning)
+
+            if len(k) > 19:
+                warnings.warn(("Scalar name '{0}' has been truncated to {1}.".format(k, k[:19])), HeaderWarning)
+
+            v = data_for_points[k]
+            self.header.nb_scalars_per_point += v.shape[1]
+
+            scalar_name = k[:19].ljust(19, '\x00') + np.array(v.shape[1], dtype=np.int8).tostring()
+            self.header.extra['scalar_name'][i] = scalar_name
+
         trk_writer = TrkWriter(fileobj, self.header)
         trk_writer.write(self.tractogram)

From 0e84d01069e1f7f522d1a0a906f4eb8288cb20d3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Tue, 17 Nov 2015 22:47:48 -0500
Subject: [PATCH 028/135] BF: Extend on empty CompactList is now allowed

---
 nibabel/streamlines/compact_list.py            | 4 ++++
 nibabel/streamlines/tests/test_compact_list.py | 8 ++++++++
 2 files changed, 12 insertions(+)

diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py
index 87671bb5c2..8ead3ed50d 100644
--- a/nibabel/streamlines/compact_list.py
+++ b/nibabel/streamlines/compact_list.py
@@ -101,6 +101,10 @@ def extend(self, elements):
             shape except for the first dimension.
 
         """
+        if self._data is None:
+            elem = np.asarray(elements[0])
+            self._data = np.zeros((0, elem.shape[1]), dtype=elem.dtype)
+
         if isinstance(elements, CompactList):
             self._data = np.concatenate([self._data, elements._data], axis=0)
             lengths = elements._lengths
diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py
index 2211d9d303..90783a29b0 100644
--- a/nibabel/streamlines/tests/test_compact_list.py
+++ b/nibabel/streamlines/tests/test_compact_list.py
@@ -173,6 +173,14 @@ def test_compactlist_extend(self):
         assert_equal(clist._lengths[-len(new_data):], lengths)
         assert_array_equal(clist._data[-sum(lengths):], new_data._data)
 
+        # Test extending an empty CompactList
+        clist = CompactList()
+        clist.extend(new_data)
+        assert_equal(len(clist), len(new_data))
+        assert_array_equal(clist._offsets, new_data._offsets)
+        assert_array_equal(clist._lengths, new_data._lengths)
+        assert_array_equal(clist._data, new_data._data)
+
     def test_compactlist_getitem(self):
         # Get one item
         for i, e in enumerate(self.clist):
             assert_array_equal(self.clist[i], e)

From 9ed55b8316ca015ed8ff6f29752064bd9bf64a59 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Wed, 18 Nov 2015 10:30:05 -0500
Subject: [PATCH 029/135] BF: Support creating TractogramHeader from dict

---
 nibabel/streamlines/header.py          | 22 +++++++++++++++++++++-
 nibabel/streamlines/tests/test_trk.py  | 13 +++++++++++--
 nibabel/streamlines/tractogram_file.py |  3 ++-
 nibabel/streamlines/trk.py             |  3 +--
 4 files changed, 35 insertions(+), 6 deletions(-)

diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py
index 2298e395b2..4b7ea30605 100644
--- a/nibabel/streamlines/header.py
+++ b/nibabel/streamlines/header.py
@@ -25,13 +25,33 @@ class Field:
 
 
 class TractogramHeader(object):
-    def __init__(self):
+    def __init__(self, hdr=None):
         self._nb_streamlines = None
         self._nb_scalars_per_point = None
         self._nb_properties_per_streamline = None
         self._to_world_space = np.eye(4)
         self.extra = OrderedDict()
 
+        if type(hdr) is dict:
+            if 
Field.NB_POINTS in hdr: + self.nb_streamlines = hdr[Field.NB_POINTS] + + if Field.NB_SCALARS_PER_POINT in hdr: + self.nb_scalars_per_point = hdr[Field.NB_SCALARS_PER_POINT] + + if Field.NB_PROPERTIES_PER_STREAMLINE in hdr: + self.nb_properties_per_streamline = hdr[Field.NB_PROPERTIES_PER_STREAMLINE] + + if Field.VOXEL_TO_RASMM in hdr: + self.to_world_space = hdr[Field.VOXEL_TO_RASMM] + + elif type(hdr) is TractogramHeader: + self.nb_streamlines = hdr.nb_streamlines + self.nb_scalars_per_point = hdr.nb_scalars_per_point + self.nb_properties_per_streamline = hdr.nb_properties_per_streamline + self.to_world_space = hdr.to_world_space + self.extra = copy.deepcopy(hdr.extra) + @property def to_world_space(self): return self._to_world_space diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 1b1c2ab7a6..3322140097 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -9,8 +9,8 @@ from nose.tools import assert_equal, assert_raises, assert_true from .. import base_format -from ..base_format import Tractogram, LazyTractogram -from ..base_format import DataError, HeaderError, HeaderWarning, UsageWarning +from ..tractogram import Tractogram, LazyTractogram +from ..base_format import DataError, HeaderError, HeaderWarning#, UsageWarning #from .. import trk from ..trk import TrkFile @@ -254,6 +254,15 @@ def test_write_erroneous_file(self): trk = TrkFile(tractogram, ref=self.affine) assert_raises(IndexError, trk.save, BytesIO()) + def test_load_write_simple_file(self): + trk = TrkFile.load(self.simple_trk_filename, lazy_load=False) + trk_file = BytesIO() + trk.save(trk_file) + + # trk = TrkFile.load(self.simple_trk_filename, lazy_load=True) + # check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) + + # def test_write_file_lazy_tractogram(self): # streamlines = lambda: (point for point in self.streamlines) # scalars = lambda: (scalar for scalar in self.colors) diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index 2ec5b69a68..d3f54c2a84 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -17,7 +17,8 @@ class TractogramFile(object): def __init__(self, tractogram, header=None): self._tractogram = tractogram - self._header = TractogramHeader() if header is None else header + #self._header = TractogramHeader() if header is None else header + self._header = TractogramHeader(header) @property def tractogram(self): diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index b8c325aca0..92b5466ef0 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -578,7 +578,6 @@ def _read(): if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: #tractogram.data_per_streamline['properties'] = properties - nb_properties = np.fromstring(property_name[-1], np.int8) tractogram.data_per_streamline['properties'] = properties[:, cpt:] # Bring tractogram to RAS+ and mm space @@ -592,7 +591,7 @@ def _read(): #if tractogram.header.nb_properties_per_streamline != trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: # raise HeaderError("'nb_properties_per_streamline' does not match.") - return cls(tractogram, ref=affine, header=trk_reader.header) + return cls(tractogram, header=trk_reader.header, ref=affine) def save(self, fileobj): ''' Saves tractogram to a file-like object using TRK format. 
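A minimal usage sketch of the dict support added by patch 029 (an illustration, not part of the patch series; it assumes the `nibabel.streamlines.header` module as it stands after this patch, with the `Field` constants handled in `__init__` above):

    import numpy as np

    from nibabel.streamlines.header import Field, TractogramHeader

    # Seed a TractogramHeader from a plain dict, e.g. the raw header
    # mapping produced by TrkReader; keys not handled in __init__ are
    # simply ignored.
    hdr = TractogramHeader({Field.NB_SCALARS_PER_POINT: 3,
                            Field.NB_PROPERTIES_PER_STREAMLINE: 2,
                            Field.VOXEL_TO_RASMM: np.eye(4)})

    print(hdr.nb_scalars_per_point)          # 3
    print(hdr.nb_properties_per_streamline)  # 2
    print(hdr.to_world_space)                # 4x4 identity
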
From d3af0b31f99bebd025e4b2f76c3561719cf9ad16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 18 Nov 2015 15:24:03 -0500 Subject: [PATCH 030/135] BF: Fixed creating TractogramHeader from another TractogramHeader when its nb_streamlines was None --- nibabel/streamlines/header.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py index 4b7ea30605..3fac6952bd 100644 --- a/nibabel/streamlines/header.py +++ b/nibabel/streamlines/header.py @@ -46,10 +46,10 @@ def __init__(self, hdr=None): self.to_world_space = hdr[Field.VOXEL_TO_RASMM] elif type(hdr) is TractogramHeader: - self.nb_streamlines = hdr.nb_streamlines - self.nb_scalars_per_point = hdr.nb_scalars_per_point - self.nb_properties_per_streamline = hdr.nb_properties_per_streamline - self.to_world_space = hdr.to_world_space + self._nb_streamlines = hdr._nb_streamlines + self._nb_scalars_per_point = hdr._nb_scalars_per_point + self._nb_properties_per_streamline = hdr._nb_properties_per_streamline + self._to_world_space = hdr._to_world_space self.extra = copy.deepcopy(hdr.extra) @property From aaf2e96ae5f354ab562784ad39b39d5de590ad51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Thu, 19 Nov 2015 01:01:01 -0500 Subject: [PATCH 031/135] BF: Not all property and scalar name were save in the TRK header --- nibabel/streamlines/tests/data/complex.trk | Bin 1228 -> 1228 bytes nibabel/streamlines/tests/test_streamlines.py | 6 +++--- nibabel/streamlines/tests/test_trk.py | 3 +++ nibabel/streamlines/trk.py | 11 ++++++++++- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/nibabel/streamlines/tests/data/complex.trk b/nibabel/streamlines/tests/data/complex.trk index 9bfbe5ea60917d6f98920b3a08e460be5fb6732d..e6155154c43fa34577517fefce8f1ac14b9a4ce7 100644 GIT binary patch delta 64 zcmX@Zd4_X>8hdeaVoqXF@kEJ(VoVGLMfnA(MJ1W3#SAdOG|_O|+C=Nk69YsiGc#@nQRf(KfLLU+0Mio2i2-bzw=hdE0suGt4gvrG diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 5dabad1815..e27e7cc42a 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -12,9 +12,9 @@ from nibabel.testing import assert_arrays_equal, isiterable from nose.tools import assert_equal, assert_raises, assert_true, assert_false -from ..base_format import Tractogram, LazyTractogram, TractogramFile -from ..base_format import HeaderError, UsageWarning -from ..header import Field +from ..tractogram import Tractogram, LazyTractogram +from ..tractogram_file import TractogramFile +from ..tractogram import UsageWarning from .. 
import trk
 
 DATA_PATH = pjoin(os.path.dirname(__file__), 'data')
diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py
index 3322140097..ed7317db5a 100644
--- a/nibabel/streamlines/tests/test_trk.py
+++ b/nibabel/streamlines/tests/test_trk.py
@@ -206,6 +206,9 @@ def test_write_complex_file(self):
         assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram)
 
         trk_file.seek(0, os.SEEK_SET)
+        #raw = trk_file.read()
+        #raw = raw[:38] + "\x00"*200 + raw[38+200:]  # Overwrite the scalar name section
+        #raw = raw[:240] + "\x00"*200 + raw[240+200:]  # Overwrite the property name section
         assert_equal(open(self.complex_trk_filename, 'rb').read(), trk_file.read())
 
     def test_write_erroneous_file(self):
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index 92b5466ef0..6c97c1b1fd 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -218,11 +218,20 @@ def create_empty_header(cls):
     def __init__(self, fileobj, header):
         self.header = self.create_empty_header()
 
-        # Override hdr's fields by those contain in `header`.
+        # Override hdr's fields by those contained in `header`.
         for k, v in header.extra.items():
             if k in header_2_dtype.fields.keys():
                 self.header[k] = v
 
+        # TODO: Fix that ugly patch.
+        # Because the assignment operator on ndarray of string only copies the
+        # first entry, we have to do it explicitly!
+        if "property_name" in header.extra:
+            self.header["property_name"][:] = header.extra["property_name"][:]
+
+        if "scalar_name" in header.extra:
+            self.header["scalar_name"][:] = header.extra["scalar_name"][:]
+
         self.header[Field.NB_STREAMLINES] = 0
         if header.nb_streamlines is not None:
             self.header[Field.NB_STREAMLINES] = header.nb_streamlines

From 277f211d5e4eff2b45edf1294067acb5e7661269 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Thu, 19 Nov 2015 14:51:13 -0500
Subject: [PATCH 032/135] ENH: CompactList supports advanced indexing with
 ndarrays of dtype bool

---
 nibabel/streamlines/compact_list.py            | 47 +++++++++++--------
 .../streamlines/tests/test_compact_list.py     | 13 +++++
 2 files changed, 41 insertions(+), 19 deletions(-)

diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py
index 8ead3ed50d..fb98536bab 100644
--- a/nibabel/streamlines/compact_list.py
+++ b/nibabel/streamlines/compact_list.py
@@ -154,18 +154,27 @@ def __getitem__(self, idx):
             return self._data[start:start+self._lengths[idx]]
 
         elif type(idx) is slice:
-            compact_list = CompactList()
-            compact_list._data = self._data
-            compact_list._offsets = self._offsets[idx]
-            compact_list._lengths = self._lengths[idx]
-            return compact_list
+            clist = CompactList()
+            clist._data = self._data
+            clist._offsets = self._offsets[idx]
+            clist._lengths = self._lengths[idx]
+            return clist
 
         elif type(idx) is list:
-            compact_list = CompactList()
-            compact_list._data = self._data
-            compact_list._offsets = [self._offsets[i] for i in idx]
-            compact_list._lengths = [self._lengths[i] for i in idx]
-            return compact_list
+            clist = CompactList()
+            clist._data = self._data
+            clist._offsets = [self._offsets[i] for i in idx]
+            clist._lengths = [self._lengths[i] for i in idx]
+            return clist
+
+        elif isinstance(idx, np.ndarray) and idx.dtype == np.bool:
+            clist = CompactList()
+            clist._data = self._data
+            clist._offsets = [self._offsets[i]
+                              for i, take_it in enumerate(idx) if take_it]
+            clist._lengths = [self._lengths[i]
+                              for i, take_it in enumerate(idx) if take_it]
+            return clist
 
         raise TypeError("Index must be a int or a slice! Not " + str(type(idx)))
@@ -184,19 +193,19 @@ def __repr__(self):
         return repr(list(self))
 
 
-def save_compact_list(filename, compact_list):
+def save_compact_list(filename, clist):
     """ Saves a `CompactList` object to a .npz file. """
     np.savez(filename,
-             data=compact_list._data,
-             offsets=compact_list._offsets,
-             lengths=compact_list._lengths)
+             data=clist._data,
+             offsets=clist._offsets,
+             lengths=clist._lengths)
 
 
 def load_compact_list(filename):
     """ Loads a `CompactList` object from a .npz file. """
     content = np.load(filename)
-    compact_list = CompactList()
-    compact_list._data = content["data"]
-    compact_list._offsets = content["offsets"].tolist()
-    compact_list._lengths = content["lengths"].tolist()
-    return compact_list
+    clist = CompactList()
+    clist._data = content["data"]
+    clist._offsets = content["offsets"].tolist()
+    clist._lengths = content["lengths"].tolist()
+    return clist
diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py
index 90783a29b0..71dac112b7 100644
--- a/nibabel/streamlines/tests/test_compact_list.py
+++ b/nibabel/streamlines/tests/test_compact_list.py
@@ -206,6 +206,19 @@ def test_compactlist_getitem(self):
         for i, e in enumerate(clist_view):
             assert_array_equal(e, self.clist[i*2])
 
+        # Use advanced indexing with an ndarray of dtype bool.
+        idx = np.array([False, True, True, False, True])
+        clist_view = self.clist[idx]
+        assert_true(clist_view is not self.clist)
+        assert_true(clist_view._data is self.clist._data)
+        assert_array_equal(clist_view._offsets,
+                           np.asarray(self.clist._offsets)[idx])
+        assert_array_equal(clist_view._lengths,
+                           np.asarray(self.clist._lengths)[idx])
+        assert_array_equal(clist_view[0], self.clist[1])
+        assert_array_equal(clist_view[1], self.clist[2])
+        assert_array_equal(clist_view[2], self.clist[4])
+
 
 def test_save_and_load_compact_list():

From b8d47fddb3bf832cfd78acd4bd6ed0bb4c4b7310 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Thu, 19 Nov 2015 14:59:59 -0500
Subject: [PATCH 033/135] ENH: Tractogram supports advanced indexing with
 ndarrays of dtype bool or integer

---
 nibabel/streamlines/tractogram.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
index f7907ca888..c4ed299dc5 100644
--- a/nibabel/streamlines/tractogram.py
+++ b/nibabel/streamlines/tractogram.py
@@ -111,35 +111,35 @@ def __iter__(self):
     def __getitem__(self, idx):
         pts = self.streamlines[idx]
 
-        new_data_per_streamline = {}
+        data_per_streamline = {}
         for key in self.data_per_streamline:
-            new_data_per_streamline[key] = self.data_per_streamline[key][idx]
+            data_per_streamline[key] = self.data_per_streamline[key][idx]
 
-        new_data_per_point = {}
+        data_per_point = {}
         for key in self.data_per_point:
-            new_data_per_point[key] = self.data_per_point[key][idx]
+            data_per_point[key] = self.data_per_point[key][idx]
 
-        if type(idx) is slice:
-            return Tractogram(pts, new_data_per_streamline, new_data_per_point)
+        if isinstance(idx, int) or isinstance(idx, np.integer):
+            return TractogramItem(pts, data_per_streamline, data_per_point)
 
-        return TractogramItem(pts, new_data_per_streamline, new_data_per_point)
+        return Tractogram(pts, data_per_streamline, data_per_point)
 
     def __len__(self):
         return len(self.streamlines)
 
     def copy(self):
         """ Returns a copy of this `Tractogram` object.
""" - new_data_per_streamline = {} + data_per_streamline = {} for key in self.data_per_streamline: - new_data_per_streamline[key] = self.data_per_streamline[key].copy() + data_per_streamline[key] = self.data_per_streamline[key].copy() - new_data_per_point = {} + data_per_point = {} for key in self.data_per_point: - new_data_per_point[key] = self.data_per_point[key].copy() + data_per_point[key] = self.data_per_point[key].copy() tractogram = Tractogram(self.streamlines.copy(), - new_data_per_streamline, - new_data_per_point) + data_per_streamline, + data_per_point) return tractogram def apply_affine(self, affine): From 72278b191b72e7e53c659c8af6108297aa4a9649 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 22 Nov 2015 01:52:21 -0500 Subject: [PATCH 034/135] Added support for voxel order other than RAS in TRK file format --- nibabel/streamlines/tests/data/complex.trk | Bin 1228 -> 1296 bytes nibabel/streamlines/tests/test_trk.py | 117 ++++++++--- nibabel/streamlines/tractogram_file.py | 3 +- nibabel/streamlines/trk.py | 234 +++++++++++++-------- 4 files changed, 238 insertions(+), 116 deletions(-) diff --git a/nibabel/streamlines/tests/data/complex.trk b/nibabel/streamlines/tests/data/complex.trk index e6155154c43fa34577517fefce8f1ac14b9a4ce7..8aa099becd3bf9436d6cd0cc4ded087c3e794f8b 100644 GIT binary patch literal 1296 zcmWFua&-1)U<5-3h6Z~CW`F}0hUEO5{GwvG0EoeymWaX!aTqZ~2AKdWLvCtfUOcLI zm?2`NMP-R4rA4V=Co_V@N`Riu%+G^*VvzX`6j$hg5;1hMM)`v^1cDrciEc&eX>8WanLs@W4h#$v zfOrECg8)SC?3pt_IRxL9i`}7GgwKH$B8X-lTpdsx-8_igteKKPxeE>q3?G1)!x1D3 u0S)#LISHVc1`u1I$Qc_O1NC_T#Ug+>14Rz#XP_JmG}zm6xdYwk=Kuh;XGW#~ delta 157 zcmbQhb%t|-iWoCPadKi#Vo@;z5@4Rp$h1X_iJ_n= MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE: + warnings.warn(("Can only store {0} named properties: '{1}' will be omitted.".format(MAX_NB_NAMED_SCALARS_PER_POINT, k)), HeaderWarning) + + if len(k) > 19: + warnings.warn(("Property name '{0}' has be truncated to {1}.".format(k, k[:19])), HeaderWarning) + + v = data_for_streamline[k] + property_name = k[:19].ljust(19, '\x00') + np.array(v.shape[0], dtype=np.int8).tostring() + self.header['property_name'][i] = property_name + + # Update the 'scalar_name' field using 'data_per_point' of the tractogram. + data_for_points = tractogram[0].data_for_points + data_for_points_keys = sorted(data_for_points.keys())[:MAX_NB_NAMED_SCALARS_PER_POINT] + self.header['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20') + for i, k in enumerate(data_for_points_keys): + if i >= MAX_NB_NAMED_SCALARS_PER_POINT: + warnings.warn(("Can only store {0} named scalars: '{1}' will be omitted.".format(MAX_NB_NAMED_SCALARS_PER_POINT, k)), HeaderWarning) + + if len(k) > 19: + warnings.warn(("Scalar name '{0}' has be truncated to {1}.".format(k, k[:19])), HeaderWarning) + + v = data_for_points[k] + scalar_name = k[:19].ljust(19, '\x00') + np.array(v.shape[1], dtype=np.int8).tostring() + self.header['scalar_name'][i] = scalar_name + + + # `Tractogram` streamlines are in RAS+ and mm space, we will compute + # the affine matrix that will bring them back to 'voxelmm' as required + # by the TRK format. + affine = np.eye(4) # TrackVis considers coordinate (0,0,0) to be the corner of the - # voxel whereas streamlines passed in parameters assume (0,0,0) - # to be the center of the voxel. Thus, streamlines are shifted of - # half a voxel. - affine[:-1, -1] += np.array(self.header[Field.VOXEL_SIZES])/2. + # voxel whereas `Tractogram` streamlines assume (0,0,0) is the + # center of the voxel. 
Thus, streamlines are shifted of half a voxel. + offset = np.eye(4) + offset[:-1, -1] += np.array(self.header[Field.VOXEL_SIZES])/2. + affine = np.dot(offset, affine) + + + # Applied the inverse of the affine found in the TRK header. + # rasmm -> voxel + affine = np.dot(np.linalg.inv(self.header[Field.VOXEL_TO_RASMM]), affine) + + # If the voxel order implied by the affine does not match the voxel + # order in the TRK header, change the orientation. + # voxel (affine) -> voxel (header) + header_ornt = self.header[Field.VOXEL_ORDER] + affine_ornt = "".join(nib.orientations.aff2axcodes(self.header[Field.VOXEL_TO_RASMM])) + header_ornt = nib.orientations.axcodes2ornt(header_ornt) + affine_ornt = nib.orientations.axcodes2ornt(affine_ornt) + ornt = nib.orientations.ornt_transform(affine_ornt, header_ornt) + M = nib.orientations.inv_ornt_aff(ornt, self.header[Field.DIMENSIONS]) + affine = np.dot(M, affine) - tractogram.apply_affine(affine) + # Finally send the streamlines in mm space. + # voxel -> voxelmm + scale = np.eye(4) + scale[range(3), range(3)] *= self.header[Field.VOXEL_SIZES] + affine = np.dot(scale, affine) + + # The TRK format uses float32 as the data type for points. + affine = affine.astype(np.float32) for t in tractogram: if any((len(d) != len(t.streamline) for d in t.data_for_points.values())): raise DataError("Missing scalars for some points!") - points = np.asarray(t.streamline, dtype=f4_dtype) - keys = sorted(t.data_for_points.keys())[:MAX_NB_NAMED_SCALARS_PER_POINT] - scalars = np.asarray([t.data_for_points[k] for k in keys], dtype=f4_dtype).reshape((len(points), -1)) - keys = sorted(t.data_for_streamline.keys())[:MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE] - properties = np.asarray([t.data_for_streamline[k] for k in keys], dtype=f4_dtype).flatten() + points = apply_affine(affine, np.asarray(t.streamline, dtype=f4_dtype)) + scalars = [np.asarray(t.data_for_points[k], dtype=f4_dtype) for k in data_for_points_keys] + scalars = np.concatenate([np.ndarray((len(points), 0), dtype=f4_dtype)] + scalars, axis=1) + properties = [np.asarray(t.data_for_streamline[k], dtype=f4_dtype) for k in data_for_streamline_keys] + properties = np.concatenate([np.array([], dtype=f4_dtype)] + properties) data = struct.pack(i4_dtype.str[:-1], len(points)) - data += np.concatenate((points, scalars), axis=1).tostring() + data += np.concatenate([points, scalars], axis=1).tostring() data += properties.tostring() self.file.write(data) @@ -291,8 +383,7 @@ def write(self, tractogram): self.nb_scalars += scalars.size self.nb_properties += len(properties) - # Either correct or warn if header and data are incoherent. - #TODO: add a warn option as a function parameter + # Use those values to update the header. nb_scalars_per_point = self.nb_scalars / self.nb_points nb_properties_per_streamline = self.nb_properties / self.nb_streamlines @@ -429,8 +520,12 @@ def __init__(self, tractogram, header=None, ref=np.eye(4)): Streamlines of the tractogram are assumed to be in *RAS+* and *mm* space where coordinate (0,0,0) refers to the center of the voxel. 
""" + if header is None: + header_rec = TrkWriter.create_empty_header() + header = dict(zip(header_rec.dtype.names, header_rec)) + super(TrkFile, self).__init__(tractogram, header) - self._affine = get_affine_from_reference(ref) + #self._affine = get_affine_from_reference(ref) @classmethod def get_magic_number(cls): @@ -498,13 +593,20 @@ def load(cls, fileobj, lazy_load=False): ''' trk_reader = TrkReader(fileobj) - # TRK's streamlines are in 'voxelmm' space, we send them to rasmm. - # First send them to voxel space. + # TRK's streamlines are in 'voxelmm' space, we will compute the + # affine matrix that will bring them back to RAS+ and mm space. affine = np.eye(4) - affine[range(3), range(3)] /= trk_reader.header[Field.VOXEL_SIZES] - # If voxel order implied from the affine does not match the voxel - # order save in the TRK header, change the orientation. + # The affine matrix found in the TRK header requires the points to be + # in the voxel space. + # voxelmm -> voxel + scale = np.eye(4) + scale[range(3), range(3)] /= trk_reader.header[Field.VOXEL_SIZES] + affine = np.dot(scale, affine) + + # If the voxel order implied by the affine does not match the voxel + # order in the TRK header, change the orientation. + # voxel (header) -> voxel (affine) header_ornt = trk_reader.header[Field.VOXEL_ORDER] affine_ornt = "".join(nib.orientations.aff2axcodes(trk_reader.header[Field.VOXEL_TO_RASMM])) header_ornt = nib.orientations.axcodes2ornt(header_ornt) @@ -513,14 +615,16 @@ def load(cls, fileobj, lazy_load=False): M = nib.orientations.inv_ornt_aff(ornt, trk_reader.header[Field.DIMENSIONS]) affine = np.dot(M, affine) - # Applied the affine going from voxel space to rasmm. + # Applied the affine found in the TRK header. + # voxel -> rasmm affine = np.dot(trk_reader.header[Field.VOXEL_TO_RASMM], affine) - # TrackVis considers coordinate (0,0,0) to be the corner of the - # voxel whereas streamlines returned assume (0,0,0) to be the - # center of the voxel. Thus, streamlines are shifted of half - #a voxel. - affine[:-1, -1] -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2. + # TrackVis considers coordinate (0,0,0) to be the corner of the voxel + # whereas streamlines returned assume (0,0,0) to be the center of the + # voxel. Thus, streamlines are shifted of half a voxel. + offset = np.eye(4) + offset[:-1, -1] -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2. + affine = np.dot(offset, affine) if lazy_load: # TODO when LazyTractogram has been refactored. @@ -590,7 +694,7 @@ def _read(): tractogram.data_per_streamline['properties'] = properties[:, cpt:] # Bring tractogram to RAS+ and mm space - tractogram.apply_affine(affine) + tractogram.apply_affine(affine.astype(np.float32)) ## Perform some integrity checks #if tractogram.header.voxel_sizes != trk_reader.header[Field.VOXEL_SIZES]: @@ -612,40 +716,6 @@ def save(self, fileobj): pointing to TRK file (and ready to read from the beginning of the TRK header data). ''' - # Compute how many properties per streamline the tractogram has. 
-        self.header.nb_properties_per_streamline = 0
-        self.header.extra['property_name'] = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, dtype='S20')
-        data_for_streamline = self.tractogram[0].data_for_streamline
-        for i, k in enumerate(sorted(data_for_streamline.keys())):
-            if i >= MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE:
-                warnings.warn(("Can only store {0} named properties: '{1}' will be omitted.".format(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, k)), HeaderWarning)
-
-            if len(k) > 19:
-                warnings.warn(("Property name '{0}' has been truncated to {1}.".format(k, k[:19])), HeaderWarning)
-
-            v = data_for_streamline[k]
-            self.header.nb_properties_per_streamline += v.shape[0]
-
-            property_name = k[:19].ljust(19, '\x00') + np.array(v.shape[0], dtype=np.int8).tostring()
-            self.header.extra['property_name'][i] = property_name
-
-        # Compute how many scalars per point the tractogram has.
-        self.header.nb_scalars_per_point = 0
-        self.header.extra['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20')
-        data_for_points = self.tractogram[0].data_for_points
-        for i, k in enumerate(sorted(data_for_points.keys())):
-            if i >= MAX_NB_NAMED_SCALARS_PER_POINT:
-                warnings.warn(("Can only store {0} named scalars: '{1}' will be omitted.".format(MAX_NB_NAMED_SCALARS_PER_POINT, k)), HeaderWarning)
-
-            if len(k) > 19:
-                warnings.warn(("Scalar name '{0}' has been truncated to {1}.".format(k, k[:19])), HeaderWarning)
-
-            v = data_for_points[k]
-            self.header.nb_scalars_per_point += v.shape[1]
-
-            scalar_name = k[:19].ljust(19, '\x00') + np.array(v.shape[1], dtype=np.int8).tostring()
-            self.header.extra['scalar_name'][i] = scalar_name
-
         trk_writer = TrkWriter(fileobj, self.header)
         trk_writer.write(self.tractogram)

From 899f6a6884e0fdc68c4962101a4a551ee021ff97 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3%B4t=C3=A9?=
Date: Sun, 22 Nov 2015 02:21:35 -0500
Subject: [PATCH 035/135] Fixed LazyTractogram

---
 nibabel/streamlines/tests/test_tractogram.py | 87 +++++++++-----------
 nibabel/streamlines/tractogram.py            |  6 +-
 2 files changed, 39 insertions(+), 54 deletions(-)

diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py
index f106e5e08b..fe7b414ef6 100644
--- a/nibabel/streamlines/tests/test_tractogram.py
+++ b/nibabel/streamlines/tests/test_tractogram.py
@@ -203,6 +203,10 @@ def setUp(self):
 
         self.nb_streamlines = len(self.streamlines)
 
+        self.colors_func = lambda: (x for x in self.colors)
+        self.mean_curvature_func = lambda: (x for x in self.mean_curvature)
+        self.mean_color_func = lambda: (x for x in self.mean_color)
+
     def test_lazy_tractogram_creation(self):
         # To create tractogram from arrays use `Tractogram`.
assert_raises(TypeError, LazyTractogram, self.streamlines) @@ -231,9 +235,9 @@ def test_lazy_tractogram_creation(self): # Create tractogram with streamlines and other data streamlines = lambda: (x for x in self.streamlines) - data_per_point = {"colors": lambda: (x for x in self.colors)} - data_per_streamline = {'mean_curv': lambda: (x for x in self.mean_curvature), - 'mean_color': lambda: (x for x in self.mean_color)} + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} tractogram = LazyTractogram(streamlines, data_per_streamline=data_per_streamline, @@ -274,75 +278,60 @@ def _data_gen(): def test_lazy_tractogram_indexing(self): streamlines = lambda: (x for x in self.streamlines) - scalars = lambda: (x for x in self.colors) - properties = lambda: (x for x in self.mean_curvature_torsion) + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} # By default, `LazyTractogram` object does not support indexing. - tractogram = LazyTractogram(streamlines, scalars, properties) + tractogram = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) assert_raises(AttributeError, tractogram.__getitem__, 0) - # Create a `LazyTractogram` object with indexing support. - def getitem_without_properties(idx): - if isinstance(idx, int) or isinstance(idx, np.integer): - return self.streamlines[idx], self.colors[idx] - - return list(zip(self.streamlines[idx], self.colors[idx])) - - tractogram = LazyTractogram(streamlines, scalars, properties, - getitem_without_properties) - streamlines, scalars = tractogram[0] - assert_array_equal(streamlines, self.streamlines[0]) - assert_array_equal(scalars, self.colors[0]) - - streamlines, scalars = zip(*tractogram[::-1]) - assert_arrays_equal(streamlines, self.streamlines[::-1]) - assert_arrays_equal(scalars, self.colors[::-1]) - - streamlines, scalars = zip(*tractogram[:-1]) - assert_arrays_equal(streamlines, self.streamlines[:-1]) - assert_arrays_equal(scalars, self.colors[:-1]) - def test_lazy_tractogram_len(self): streamlines = lambda: (x for x in self.streamlines) - scalars = lambda: (x for x in self.colors) - properties = lambda: (x for x in self.mean_curvature_torsion) + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} - with clear_and_catch_warnings(record=True, modules=[module_tractogram]) as w: + modules = [module_tractogram] # Modules for which to catch warnings. + with clear_and_catch_warnings(record=True, modules=modules) as w: warnings.simplefilter("always") # Always trigger warnings. # Calling `len` will create new generators each time. - tractogram = LazyTractogram(streamlines, scalars, properties) + tractogram = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + assert_true(tractogram._nb_streamlines is None) + # This should produce a warning message. assert_equal(len(tractogram), self.nb_streamlines) + assert_equal(tractogram._nb_streamlines, self.nb_streamlines) assert_equal(len(w), 1) - tractogram = LazyTractogram(streamlines, scalars, properties) - # This should still produce a warning message. 
+            tractogram = LazyTractogram(streamlines,
+                                        data_per_streamline=data_per_streamline,
+                                        data_per_point=data_per_point)
+
+            # New instances should still produce a warning message.
             assert_equal(len(tractogram), self.nb_streamlines)
             assert_equal(len(w), 2)
             assert_true(issubclass(w[-1].category, UsageWarning))
 
-            # This should *not* produce a warning.
+            # Calling 'len' again should *not* produce a warning.
             assert_equal(len(tractogram), self.nb_streamlines)
             assert_equal(len(w), 2)
 
-        with clear_and_catch_warnings(record=True, modules=[module_tractogram]) as w:
+        with clear_and_catch_warnings(record=True, modules=modules) as w:
             # Once we iterated through the tractogram, we know the length.
-            tractogram = LazyTractogram(streamlines, scalars, properties)
-            assert_true(tractogram.header.nb_streamlines is None)
-            for streamline in tractogram:
-                pass
-            assert_equal(tractogram.header.nb_streamlines,
-                         len(self.streamlines))
-            # This should *not* produce a warning.
-            assert_equal(len(tractogram), len(self.streamlines))
-            assert_equal(len(w), 0)
+            tractogram = LazyTractogram(streamlines,
+                                        data_per_streamline=data_per_streamline,
+                                        data_per_point=data_per_point)
 
-        with clear_and_catch_warnings(record=True, modules=[module_tractogram]) as w:
-            # It first checks if number of tractogram is in the header.
-            tractogram = LazyTractogram(streamlines, scalars, properties)
-            tractogram.header.nb_streamlines = 1234
+            assert_true(tractogram._nb_streamlines is None)
+            isiterable(tractogram)  # Force to iterate through all streamlines.
+            assert_equal(tractogram._nb_streamlines, len(self.streamlines))
             # This should *not* produce a warning.
-            assert_equal(len(tractogram), 1234)
+            assert_equal(len(tractogram), len(self.streamlines))
             assert_equal(len(w), 0)
 
diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
index c4ed299dc5..2f5a641b58 100644
--- a/nibabel/streamlines/tractogram.py
+++ b/nibabel/streamlines/tractogram.py
@@ -228,7 +228,6 @@ def __init__(self, streamlines=None, data_per_streamline=None,
                                              data_per_point)
         self._nb_streamlines = None
         self._data = None
-        self._getitem = None
         self._affine_to_apply = np.eye(4)
 
     @classmethod
@@ -343,10 +342,7 @@ def _gen_data():
         return _gen_data()
 
     def __getitem__(self, idx):
-        if self._getitem is None:
-            raise AttributeError('`LazyTractogram` does not support indexing.')
-
-        return self._getitem(idx)
+        raise AttributeError('`LazyTractogram` does not support indexing.')
 
     def __iter__(self):
         i = 0

From 463ef198221255733aa0c1aaa7061da264f19dec Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Sun, 22 Nov 2015 02:48:11 -0500
Subject: [PATCH 036/135] BF: Handle empty LazyTractogram

---
 nibabel/streamlines/tests/test_tractogram.py | 13 ++++-
 nibabel/streamlines/tractogram.py            | 30 +++++++-----
 nibabel/streamlines/trk.py                   | 50 +++++++++++++------
 3 files changed, 63 insertions(+), 30 deletions(-)

diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py
index fe7b414ef6..c07b548ae1 100644
--- a/nibabel/streamlines/tests/test_tractogram.py
+++ b/nibabel/streamlines/tests/test_tractogram.py
@@ -256,7 +256,18 @@ def test_lazy_tractogram_creation(self):
             assert_arrays_equal(tractogram.data_per_point['colors'],
                                 self.colors)
 
-        # Create `LazyTractogram` from a coroutine yielding 3-tuples
+    def test_lazy_tractogram_create_from(self):
+        # Create `LazyTractogram` from a coroutine yielding nothing (i.e. empty).
+ _empty_data_gen = lambda: iter([]) + + tractogram = LazyTractogram.create_from(_empty_data_gen) + assert_true(isiterable(tractogram)) + assert_equal(len(tractogram), 0) + assert_arrays_equal(tractogram.streamlines, []) + assert_equal(tractogram.data_per_point, {}) + assert_equal(tractogram.data_per_streamline, {}) + + # Create `LazyTractogram` from a coroutine yielding TractogramItem def _data_gen(): for d in zip(self.streamlines, self.colors, self.mean_curvature, self.mean_color): diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 2f5a641b58..23e14d5685 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -249,21 +249,27 @@ def create_from(cls, data_func): lazy_tractogram = cls() lazy_tractogram._data = data_func - # Set data_per_streamline using data_func - def _gen(key): - return lambda: (t.data_for_streamline[key] for t in data_func()) + try: + first_item = next(data_func()) - data_per_streamline_keys = next(data_func()).data_for_streamline.keys() - for k in data_per_streamline_keys: - lazy_tractogram._data_per_streamline[k] = _gen(k) + # Set data_per_streamline using data_func + def _gen(key): + return lambda: (t.data_for_streamline[key] for t in data_func()) - # Set data_per_point using data_func - def _gen(key): - return lambda: (t.data_for_points[key] for t in data_func()) + data_per_streamline_keys = first_item.data_for_streamline.keys() + for k in data_per_streamline_keys: + lazy_tractogram._data_per_streamline[k] = _gen(k) - data_per_point_keys = next(data_func()).data_for_points.keys() - for k in data_per_point_keys: - lazy_tractogram._data_per_point[k] = _gen(k) + # Set data_per_point using data_func + def _gen(key): + return lambda: (t.data_for_points[key] for t in data_func()) + + data_per_point_keys = first_item.data_for_points.keys() + for k in data_per_point_keys: + lazy_tractogram._data_per_point[k] = _gen(k) + + except StopIteration: + pass return lazy_tractogram diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 27430df421..278b6e2ebe 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -627,24 +627,42 @@ def load(cls, fileobj, lazy_load=False): affine = np.dot(offset, affine) if lazy_load: - # TODO when LazyTractogram has been refactored. - def _apply_transform(trk_reader): - for pts, scals, props in trk_reader: - # TRK's streamlines are in 'voxelmm' space, we send them to voxel space. - pts = pts / trk_reader.header[Field.VOXEL_SIZES] - # TrackVis considers coordinate (0,0,0) to be the corner of the - # voxel whereas streamlines returned assume (0,0,0) to be the - # center of the voxel. Thus, streamlines are shifted of half - #a voxel. - pts -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2. 
- trk_reader - yield pts, scals, props + #pts, scals, props = next(iter(trk_reader)) + + data_per_point_slice = {} + if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0: + cpt = 0 + for scalar_name in trk_reader.header['scalar_name']: + if len(scalar_name) == 0: + continue + + nb_scalars = np.fromstring(scalar_name[-1], np.int8) + scalar_name = scalar_name.split('\x00')[0] + data_per_point_slice[scalar_name] = slice(cpt, cpt+nb_scalars) + cpt += nb_scalars + + if cpt < trk_reader.header[Field.NB_SCALARS_PER_POINT]: + data_per_point_slice['scalars'] = slice(cpt, trk_reader.header[Field.NB_SCALARS_PER_POINT]) + + data_per_streamline_slice = {} + if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: + cpt = 0 + for property_name in trk_reader.header['property_name']: + if len(property_name) == 0: + continue + + nb_properties = np.fromstring(property_name[-1], np.int8) + property_name = property_name.split('\x00')[0] + data_per_streamline_slice[property_name] = slice(cpt, cpt+nb_properties) + cpt += nb_properties + + if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: + data_per_streamline_slice['properties'] = slice(cpt, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]) def _read(): for pts, scals, props in trk_reader: - # TODO - data_for_streamline = {} - data_for_points = {} + data_for_streamline = {k: props[:, v] for k, v in data_per_streamline_slice.items()} + data_for_points = {k: scals[:, v] for k, v in data_per_point_slice.items()} yield TractogramItem(pts, data_for_streamline, data_for_points) tractogram = LazyTractogram.create_from(_read) @@ -671,7 +689,6 @@ def _read(): cpt += nb_scalars if cpt < trk_reader.header[Field.NB_SCALARS_PER_POINT]: - #tractogram.data_per_point['scalars'] = scalars clist = CompactList() clist._data = scalars._data[:, cpt:] clist._offsets = scalars._offsets @@ -690,7 +707,6 @@ def _read(): cpt += nb_properties if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: - #tractogram.data_per_streamline['properties'] = properties tractogram.data_per_streamline['properties'] = properties[:, cpt:] # Bring tractogram to RAS+ and mm space From fe6f601dd337ca92d17bb283ff5550a15e4ad341 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 22 Nov 2015 03:42:29 -0500 Subject: [PATCH 037/135] TRK supports LazyTractogram --- nibabel/streamlines/tests/test_tractogram.py | 22 +++++- nibabel/streamlines/tests/test_trk.py | 82 ++++++-------------- nibabel/streamlines/tractogram.py | 2 +- nibabel/streamlines/trk.py | 75 ++++++++++-------- 4 files changed, 89 insertions(+), 92 deletions(-) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index c07b548ae1..493be83301 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -5,7 +5,7 @@ from nibabel.testing import assert_arrays_equal, isiterable from nibabel.testing import suppress_warnings, clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true -from numpy.testing import assert_array_equal +from numpy.testing import assert_array_equal, assert_array_almost_equal from nibabel.externals.six.moves import zip from .. import tractogram as module_tractogram @@ -346,3 +346,23 @@ def test_lazy_tractogram_len(self): # This should *not* produce a warning. 
assert_equal(len(tractogram), len(self.streamlines)) assert_equal(len(w), 0) + + def test_lazy_tractogram_apply_affine(self): + streamlines = lambda: (x for x in self.streamlines) + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} + + affine = np.eye(4) + scaling = np.array((1, 2, 3), dtype=float) + affine[range(3), range(3)] = scaling + + tractogram = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + + tractogram.apply_affine(affine) + assert_true(isiterable(tractogram)) + assert_equal(len(tractogram), len(self.streamlines)) + for s1, s2 in zip(tractogram.streamlines, self.streamlines): + assert_array_almost_equal(s1, s2*scaling) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index d731873cea..4abd5b688c 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -48,6 +48,7 @@ def assert_tractogram_equal(t1, t2): def check_tractogram(tractogram, streamlines, data_per_streamline, data_per_point): # Check data + assert_true(isiterable(tractogram)) assert_equal(len(tractogram), len(streamlines)) assert_arrays_equal(tractogram.streamlines, streamlines) @@ -59,8 +60,6 @@ def check_tractogram(tractogram, streamlines, data_per_streamline, data_per_poin assert_arrays_equal(tractogram.data_per_point[key], data_per_point[key]) - assert_true(isiterable(tractogram)) - class TestTRK(unittest.TestCase): @@ -118,31 +117,22 @@ def setUp(self): self.affine = np.eye(4) def test_load_empty_file(self): - trk = TrkFile.load(self.empty_trk_filename, lazy_load=False) - check_tractogram(trk.tractogram, [], {}, {}) - - trk = TrkFile.load(self.empty_trk_filename, lazy_load=True) - # Suppress warning about loading a TRK file in lazy mode with count=0. 
- #with suppress_warnings(): - check_tractogram(trk.tractogram, [], [], []) + for lazy_load in [False, True]: + trk = TrkFile.load(self.empty_trk_filename, lazy_load=lazy_load) + check_tractogram(trk.tractogram, [], {}, {}) def test_load_simple_file(self): - trk = TrkFile.load(self.simple_trk_filename, lazy_load=False) - check_tractogram(trk.tractogram, self.streamlines, {}, {}) - - # trk = TrkFile.load(self.simple_trk_filename, lazy_load=True) - # check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) + for lazy_load in [False, True]: + trk = TrkFile.load(self.simple_trk_filename, lazy_load=lazy_load) + check_tractogram(trk.tractogram, self.streamlines, {}, {}) def test_load_complex_file(self): - trk = TrkFile.load(self.complex_trk_filename, lazy_load=False) - check_tractogram(trk.tractogram, - self.streamlines, - data_per_point=self.data_per_point, - data_per_streamline=self.data_per_streamline) - - # trk = TrkFile.load(self.complex_trk_filename, lazy_load=True) - # check_tractogram(trk.tractogram, self.nb_streamlines, - # self.streamlines, self.colors, self.mean_curvature_torsion) + for lazy_load in [False, True]: + trk = TrkFile.load(self.complex_trk_filename, lazy_load=lazy_load) + check_tractogram(trk.tractogram, + self.streamlines, + data_per_point=self.data_per_point, + data_per_streamline=self.data_per_streamline) def test_load_file_with_wrong_information(self): trk_file = open(self.simple_trk_filename, 'rb').read() @@ -153,12 +143,6 @@ def test_load_file_with_wrong_information(self): trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) check_tractogram(trk.tractogram, self.streamlines, {}, {}) - # trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=True) - # with clear_and_catch_warnings(record=True, modules=[base_format]) as w: - # check_tractogram(trk.tractogram, self.streamlines, {}, {}) - # assert_equal(len(w), 1) - # assert_true(issubclass(w[0].category, UsageWarning)) - # Simulate a TRK file where `voxel_order` was not provided. 
voxel_order = np.zeros(1, dtype="|S3").tostring() new_trk_file = trk_file[:948] + voxel_order + trk_file[948+3:] @@ -286,13 +270,18 @@ def test_write_erroneous_file(self): trk = TrkFile(tractogram, ref=self.affine) assert_raises(IndexError, trk.save, BytesIO()) - def test_load_write_simple_file(self): - trk = TrkFile.load(self.simple_trk_filename, lazy_load=False) - trk_file = BytesIO() - trk.save(trk_file) + def test_load_write_file(self): + for filename in [self.empty_trk_filename, self.simple_trk_filename, self.complex_trk_filename]: + for lazy_load in [False, True]: + trk = TrkFile.load(filename, lazy_load=lazy_load) + trk_file = BytesIO() + trk.save(trk_file) + + loaded_trk = TrkFile.load(filename, lazy_load=False) + assert_tractogram_equal(loaded_trk.tractogram, trk.tractogram) - # trk = TrkFile.load(self.simple_trk_filename, lazy_load=True) - # check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, [], []) + trk_file.seek(0, os.SEEK_SET) + #assert_equal(open(filename, 'rb').read(), trk_file.read()) def test_load_write_LPS_file(self): trk = TrkFile.load(self.simple_LPS_trk_filename, lazy_load=False) @@ -312,26 +301,3 @@ def test_load_write_LPS_file(self): trk_file.seek(0, os.SEEK_SET) assert_equal(open(self.simple_LPS_trk_filename, 'rb').read(), trk_file.read()) - - #check_tractogram(trk.tractogram, self.nb_streamlines, self.streamlines, {}, {}) - - - - # def test_write_file_lazy_tractogram(self): - # streamlines = lambda: (point for point in self.streamlines) - # scalars = lambda: (scalar for scalar in self.colors) - # properties = lambda: (prop for prop in self.mean_curvature_torsion) - - # tractogram = LazyTractogram(streamlines, scalars, properties) - # # No need to manually set `nb_streamlines` in the header since we count - # # them as writing. - # #tractogram.header.nb_streamlines = self.nb_streamlines - - # trk_file = BytesIO() - # trk = TrkFile(tractogram, ref=self.affine) - # trk.save(trk_file) - # trk_file.seek(0, os.SEEK_SET) - - # trk = TrkFile.load(trk_file, ref=None, lazy_load=False) - # check_tractogram(trk.tractogram, self.nb_streamlines, - # self.streamlines, self.colors, self.mean_curvature_torsion) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 23e14d5685..2ae433fd4b 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -287,7 +287,7 @@ def _apply_affine(): for s in streamlines_gen: yield apply_affine(self._affine_to_apply, s) - streamlines_gen = _apply_affine() + return _apply_affine() return streamlines_gen diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 278b6e2ebe..3570afa544 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -295,8 +295,20 @@ def write(self, tractogram): i4_dtype = np.dtype("i4") f4_dtype = np.dtype("f4") + try: + first_item = next(iter(tractogram)) + except StopIteration: + # Empty tractogram + self.header[Field.NB_STREAMLINES] = 0 + self.header[Field.NB_SCALARS_PER_POINT] = 0 + self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = 0 + # Overwrite header with updated one. + self.file.seek(self.beginning, os.SEEK_SET) + self.file.write(self.header.tostring()) + return + # Update the 'property_name' field using 'data_per_streamline' of the tractogram. 
- data_for_streamline = tractogram[0].data_for_streamline + data_for_streamline = first_item.data_for_streamline data_for_streamline_keys = sorted(data_for_streamline.keys())[:MAX_NB_NAMED_SCALARS_PER_POINT] self.header['property_name'] = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, dtype='S20') for i, k in enumerate(data_for_streamline_keys): @@ -311,7 +323,7 @@ def write(self, tractogram): self.header['property_name'][i] = property_name # Update the 'scalar_name' field using 'data_per_point' of the tractogram. - data_for_points = tractogram[0].data_for_points + data_for_points = first_item.data_for_points data_for_points_keys = sorted(data_for_points.keys())[:MAX_NB_NAMED_SCALARS_PER_POINT] self.header['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20') for i, k in enumerate(data_for_points_keys): @@ -636,7 +648,7 @@ def load(cls, fileobj, lazy_load=False): if len(scalar_name) == 0: continue - nb_scalars = np.fromstring(scalar_name[-1], np.int8) + nb_scalars = int(np.fromstring(scalar_name[-1], np.int8)) scalar_name = scalar_name.split('\x00')[0] data_per_point_slice[scalar_name] = slice(cpt, cpt+nb_scalars) cpt += nb_scalars @@ -651,7 +663,7 @@ def load(cls, fileobj, lazy_load=False): if len(property_name) == 0: continue - nb_properties = np.fromstring(property_name[-1], np.int8) + nb_properties = int(np.fromstring(property_name[-1], np.int8)) property_name = property_name.split('\x00')[0] data_per_streamline_slice[property_name] = slice(cpt, cpt+nb_properties) cpt += nb_properties @@ -661,8 +673,8 @@ def load(cls, fileobj, lazy_load=False): def _read(): for pts, scals, props in trk_reader: - data_for_streamline = {k: props[:, v] for k, v in data_per_streamline_slice.items()} data_for_points = {k: scals[:, v] for k, v in data_per_point_slice.items()} + data_for_streamline = {k: props[v] for k, v in data_per_streamline_slice.items()} yield TractogramItem(pts, data_for_streamline, data_for_points) tractogram = LazyTractogram.create_from(_read) @@ -735,8 +747,7 @@ def save(self, fileobj): trk_writer = TrkWriter(fileobj, self.header) trk_writer.write(self.tractogram) - @staticmethod - def pretty_print(fileobj): + def __str__(self): ''' Gets a formatted string of the header of a TRK file. Parameters @@ -751,32 +762,32 @@ def pretty_print(fileobj): info : string Header information relevant to the TRK format. 
''' - trk_reader = TrkReader(fileobj) - hdr = trk_reader.header + #trk_reader = TrkReader(fileobj) + hdr = self.header info = "" - info += "MAGIC NUMBER: {0}".format(hdr[Field.MAGIC_NUMBER]) - info += "v.{0}".format(hdr['version']) - info += "dim: {0}".format(hdr[Field.DIMENSIONS]) - info += "voxel_sizes: {0}".format(hdr[Field.VOXEL_SIZES]) - info += "orgin: {0}".format(hdr[Field.ORIGIN]) - info += "nb_scalars: {0}".format(hdr[Field.NB_SCALARS_PER_POINT]) - info += "scalar_name:\n {0}".format("\n".join(hdr['scalar_name'])) - info += "nb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) - info += "property_name:\n {0}".format("\n".join(hdr['property_name'])) - info += "vox_to_world: {0}".format(hdr[Field.VOXEL_TO_RASMM]) - info += "voxel_order: {0}".format(hdr[Field.VOXEL_ORDER]) - info += "image_orientation_patient: {0}".format(hdr['image_orientation_patient']) - info += "pad1: {0}".format(hdr['pad1']) - info += "pad2: {0}".format(hdr['pad2']) - info += "invert_x: {0}".format(hdr['invert_x']) - info += "invert_y: {0}".format(hdr['invert_y']) - info += "invert_z: {0}".format(hdr['invert_z']) - info += "swap_xy: {0}".format(hdr['swap_xy']) - info += "swap_yz: {0}".format(hdr['swap_yz']) - info += "swap_zx: {0}".format(hdr['swap_zx']) - info += "n_count: {0}".format(hdr[Field.NB_STREAMLINES]) - info += "hdr_size: {0}".format(hdr['hdr_size']) - info += "endianess: {0}".format(hdr[Field.ENDIAN]) + info += "\nMAGIC NUMBER: {0}".format(hdr[Field.MAGIC_NUMBER]) + info += "\nv.{0}".format(hdr['version']) + info += "\ndim: {0}".format(hdr[Field.DIMENSIONS]) + info += "\nvoxel_sizes: {0}".format(hdr[Field.VOXEL_SIZES]) + info += "\norgin: {0}".format(hdr[Field.ORIGIN]) + info += "\nnb_scalars: {0}".format(hdr[Field.NB_SCALARS_PER_POINT]) + info += "\nscalar_name:\n {0}".format("\n".join(hdr['scalar_name'])) + info += "\nnb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) + info += "\nproperty_name:\n {0}".format("\n".join(hdr['property_name'])) + info += "\nvox_to_world: {0}".format(hdr[Field.VOXEL_TO_RASMM]) + info += "\nvoxel_order: {0}".format(hdr[Field.VOXEL_ORDER]) + info += "\nimage_orientation_patient: {0}".format(hdr['image_orientation_patient']) + info += "\npad1: {0}".format(hdr['pad1']) + info += "\npad2: {0}".format(hdr['pad2']) + info += "\ninvert_x: {0}".format(hdr['invert_x']) + info += "\ninvert_y: {0}".format(hdr['invert_y']) + info += "\ninvert_z: {0}".format(hdr['invert_z']) + info += "\nswap_xy: {0}".format(hdr['swap_xy']) + info += "\nswap_yz: {0}".format(hdr['swap_yz']) + info += "\nswap_zx: {0}".format(hdr['swap_zx']) + info += "\nn_count: {0}".format(hdr[Field.NB_STREAMLINES]) + info += "\nhdr_size: {0}".format(hdr['hdr_size']) + #info += "endianess: {0}".format(hdr[Field.ENDIAN]) return info From be23f7643e9251167fe9f5aa8df409f38cf58b70 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 22 Nov 2015 21:11:20 -0500 Subject: [PATCH 038/135] Fixed some unit tests --- nibabel/streamlines/tests/data/empty.trk | Bin 1000 -> 1000 bytes nibabel/streamlines/tests/test_streamlines.py | 213 ++++++++++-------- nibabel/streamlines/tests/test_tractogram.py | 32 +++ nibabel/streamlines/tests/test_trk.py | 114 ++++------ 4 files changed, 190 insertions(+), 169 deletions(-) diff --git a/nibabel/streamlines/tests/data/empty.trk b/nibabel/streamlines/tests/data/empty.trk index e78e28403b087f627c298afc88c200dc2d50bcff..fbe087180770b69b1e4ff2c5f3ab1949f4a46a89 100644 GIT binary patch delta 20 ZcmaFC{(_x5B*@X(n}HDoH*())1^_$x1pWX3 
delta 20 ccmaFC{(_x5B*@X(n_&UN35Eq5x$iOq07@JO0{{R3 diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index e27e7cc42a..dfd5026102 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -12,6 +12,7 @@ from nibabel.testing import assert_arrays_equal, isiterable from nose.tools import assert_equal, assert_raises, assert_true, assert_false +from .test_tractogram import assert_tractogram_equal from ..tractogram import Tractogram, LazyTractogram from ..tractogram_file import TractogramFile from ..tractogram import UsageWarning @@ -20,22 +21,6 @@ DATA_PATH = pjoin(os.path.dirname(__file__), 'data') -def check_tractogram(tractogram, nb_streamlines, streamlines, data_per_streamline, data_per_point): - # Check data - assert_equal(len(tractogram), nb_streamlines) - assert_arrays_equal(tractogram.streamlines, streamlines) - - for key in data_per_streamline.keys(): - assert_arrays_equal(tractogram.data_per_streamline[key], - data_per_streamline[key]) - - for key in data_per_point.keys(): - assert_arrays_equal(tractogram.data_per_point[key], - data_per_point[key]) - - assert_true(isiterable(tractogram)) - - def test_is_supported(): # Emtpy file/string f = BytesIO() @@ -112,121 +97,153 @@ def test_detect_format(): class TestLoadSave(unittest.TestCase): def setUp(self): - self.empty_filenames = [pjoin(DATA_PATH, "empty" + ext) for ext in nib.streamlines.FORMATS.keys()] - self.simple_filenames = [pjoin(DATA_PATH, "simple" + ext) for ext in nib.streamlines.FORMATS.keys()] - self.complex_filenames = [pjoin(DATA_PATH, "complex" + ext) for ext in nib.streamlines.FORMATS.keys()] + self.empty_filenames = [pjoin(DATA_PATH, "empty" + ext) + for ext in nib.streamlines.FORMATS.keys()] + self.simple_filenames = [pjoin(DATA_PATH, "simple" + ext) + for ext in nib.streamlines.FORMATS.keys()] + self.complex_filenames = [pjoin(DATA_PATH, "complex" + ext) + for ext in nib.streamlines.FORMATS.keys()] self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), np.arange(2*3, dtype="f4").reshape((2, 3)), np.arange(5*3, dtype="f4").reshape((5, 3))] + self.fa = [np.array([[0.2]], dtype="f4"), + np.array([[0.3], + [0.4]], dtype="f4"), + np.array([[0.5], + [0.6], + [0.6], + [0.7], + [0.8]], dtype="f4")] + self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), np.array([(0, 1, 0)]*2, dtype="f4"), np.array([(0, 0, 1)]*5, dtype="f4")] - self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), - np.array([2.11, 2.22], dtype="f4"), - np.array([3.11, 3.22], dtype="f4")] + self.mean_curvature = [np.array([1.11], dtype="f4"), + np.array([2.11], dtype="f4"), + np.array([3.11], dtype="f4")] - self.data_per_point = {'scalars': self.colors} - self.data_per_streamline = {'properties': self.mean_curvature_torsion} + self.mean_torsion = [np.array([1.22], dtype="f4"), + np.array([2.22], dtype="f4"), + np.array([3.22], dtype="f4")] - self.nb_streamlines = len(self.streamlines) - self.nb_scalars_per_point = self.colors[0].shape[1] - self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) - self.to_world_space = np.eye(4) + self.mean_colors = [np.array([1, 0, 0], dtype="f4"), + np.array([0, 1, 0], dtype="f4"), + np.array([0, 0, 1], dtype="f4")] + + self.data_per_point = {'colors': self.colors, + 'fa': self.fa} + self.data_per_streamline = {'mean_curvature': self.mean_curvature, + 'mean_torsion': self.mean_torsion, + 'mean_colors': self.mean_colors} + + self.empty_tractogram = Tractogram() + 
self.simple_tractogram = Tractogram(self.streamlines) + self.complex_tractogram = Tractogram(self.streamlines, + self.data_per_streamline, + self.data_per_point) + + #self.nb_streamlines = len(self.streamlines) + #self.nb_scalars_per_point = self.colors[0].shape[1] + #self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) + #self.to_world_space = np.eye(4) def test_load_empty_file(self): - for empty_filename in self.empty_filenames: - tractogram_file = nib.streamlines.load(empty_filename, - lazy_load=False) - assert_true(isinstance(tractogram_file, TractogramFile)) - assert_true(type(tractogram_file.tractogram), Tractogram) - check_tractogram(tractogram_file.tractogram, 0, [], {}, {}) + for lazy_load in [False, True]: + for empty_filename in self.empty_filenames: + tfile = nib.streamlines.load(empty_filename, + lazy_load=lazy_load) + assert_true(isinstance(tfile, TractogramFile)) + + if lazy_load: + assert_true(type(tfile.tractogram), Tractogram) + else: + assert_true(type(tfile.tractogram), LazyTractogram) + + assert_tractogram_equal(tfile.tractogram, + self.empty_tractogram) def test_load_simple_file(self): - for simple_filename in self.simple_filenames: - tractogram_file = nib.streamlines.load(simple_filename, - lazy_load=False) - assert_true(isinstance(tractogram_file, TractogramFile)) - assert_true(type(tractogram_file.tractogram), Tractogram) - check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - self.streamlines, {}, {}) - - # # Test lazy_load - # tractogram_file = nib.streamlines.load(simple_filename, - # lazy_load=True) - - # assert_true(isinstance(tractogram_file, TractogramFile)) - # assert_true(type(tractogram_file.tractogram), LazyTractogram) - # check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - # self.streamlines, {}, {}) + for lazy_load in [False, True]: + for simple_filename in self.simple_filenames: + tfile = nib.streamlines.load(simple_filename, + lazy_load=lazy_load) + assert_true(isinstance(tfile, TractogramFile)) + + if lazy_load: + assert_true(type(tfile.tractogram), Tractogram) + else: + assert_true(type(tfile.tractogram), LazyTractogram) + + assert_tractogram_equal(tfile.tractogram, + self.simple_tractogram) def test_load_complex_file(self): - for complex_filename in self.complex_filenames: - file_format = nib.streamlines.detect_format(complex_filename) - - data_per_point = {} - if file_format.support_data_per_point(): - data_per_point = {'scalars': self.colors} - - data_per_streamline = [] - if file_format.support_data_per_streamline(): - data_per_streamline = {'properties': self.mean_curvature_torsion} - - tractogram_file = nib.streamlines.load(complex_filename, - lazy_load=False) - assert_true(isinstance(tractogram_file, TractogramFile)) - assert_true(type(tractogram_file.tractogram), Tractogram) - check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - self.streamlines, - data_per_streamline=data_per_streamline, - data_per_point=data_per_point) - - # # Test lazy_load - # tractogram_file = nib.streamlines.load(complex_filename, - # lazy_load=True) - # assert_true(isinstance(tractogram_file, TractogramFile)) - # assert_true(type(tractogram_file.tractogram), LazyTractogram) - # check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - # self.streamlines, - # data_per_streamline=data_per_streamline, - # data_per_point=data_per_point) + for lazy_load in [False, True]: + for complex_filename in self.complex_filenames: + tfile = nib.streamlines.load(complex_filename, + lazy_load=lazy_load) + 
assert_true(isinstance(tfile, TractogramFile))
+
+                if lazy_load:
+                    assert_true(isinstance(tfile.tractogram, LazyTractogram))
+                else:
+                    assert_true(isinstance(tfile.tractogram, Tractogram))
+
+                tractogram = Tractogram(self.streamlines)
+
+                if tfile.support_data_per_point():
+                    tractogram.data_per_point = self.data_per_point
+
+                if tfile.support_data_per_streamline():
+                    tractogram.data_per_streamline = self.data_per_streamline
+
+                assert_tractogram_equal(tfile.tractogram,
+                                        tractogram)
+
+    def test_save_empty_file(self):
+        tractogram = Tractogram()
+        for ext, cls in nib.streamlines.FORMATS.items():
+            with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f:
+                nib.streamlines.save_tractogram(tractogram, f.name)
+                tfile = nib.streamlines.load(f, lazy_load=False)
+                assert_tractogram_equal(tfile.tractogram, tractogram)
 
     def test_save_simple_file(self):
         tractogram = Tractogram(self.streamlines)
         for ext, cls in nib.streamlines.FORMATS.items():
             with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f:
                 nib.streamlines.save_tractogram(tractogram, f.name)
-                tractogram_file = nib.streamlines.load(f, lazy_load=False)
-                check_tractogram(tractogram_file.tractogram, self.nb_streamlines,
-                                 self.streamlines, {}, {})
+                tfile = nib.streamlines.load(f, lazy_load=False)
+                assert_tractogram_equal(tfile.tractogram, tractogram)
 
     def test_save_complex_file(self):
-        tractogram = Tractogram(self.streamlines,
-                                data_per_streamline=self.data_per_streamline,
-                                data_per_point=self.data_per_point)
+        complex_tractogram = Tractogram(self.streamlines,
+                                        self.data_per_streamline,
+                                        self.data_per_point)
+
         for ext, cls in nib.streamlines.FORMATS.items():
             with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f:
                 with clear_and_catch_warnings(record=True, modules=[trk]) as w:
-                    nib.streamlines.save_tractogram(tractogram, f.name)
+                    nib.streamlines.save_tractogram(complex_tractogram, f.name)
 
-                    # If streamlines format does not support saving data per point
-                    # or data per streamline, a warning message should be issued.
-                    if not (cls.support_data_per_point() and cls.support_data_per_streamline()):
+                    # If streamlines format does not support saving data per
+                    # point or data per streamline, a warning message should
+                    # be issued.
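A caller can ask a format class the same questions this test asks; a hedged sketch reusing the `support_data_per_point`/`support_data_per_streamline` classmethods and the `UsageWarning` category from this series, with `complex_tractogram` as in the test above and `'out' + ext` as a placeholder output path::

    import warnings
    import nibabel as nib
    from nibabel.streamlines.tractogram import UsageWarning

    for ext, cls in nib.streamlines.FORMATS.items():
        if not (cls.support_data_per_point()
                and cls.support_data_per_streamline()):
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")
                # Extra data get dropped by this format; expect a warning.
                nib.streamlines.save_tractogram(complex_tractogram,
                                                'out' + ext)
            assert any(issubclass(c.category, UsageWarning) for c in caught)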
+ if not (cls.support_data_per_point() + and cls.support_data_per_streamline()): assert_equal(len(w), 1) assert_true(issubclass(w[0].category, UsageWarning)) - data_per_point = {} + tractogram = Tractogram(self.streamlines) + if cls.support_data_per_point(): - data_per_point = self.data_per_point + tractogram.data_per_point = self.data_per_point - data_per_streamline = [] if cls.support_data_per_streamline(): - data_per_streamline = self.data_per_streamline + tractogram.data_per_streamline = self.data_per_streamline - tractogram_file = nib.streamlines.load(f, ref=self.to_world_space, lazy_load=False) - check_tractogram(tractogram_file.tractogram, self.nb_streamlines, - self.streamlines, - data_per_streamline=data_per_streamline, - data_per_point=data_per_point) + tfile = nib.streamlines.load(f, lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 493be83301..ac458bae32 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -13,6 +13,38 @@ from ..tractogram import TractogramItem, Tractogram, LazyTractogram +# def check_tractogram(tractogram, streamlines, +# data_per_streamline, data_per_point): +# # Check data +# assert_true(isiterable(tractogram)) +# assert_equal(len(tractogram), len(streamlines)) +# assert_arrays_equal(tractogram.streamlines, streamlines) + +# for key in data_per_streamline.keys(): +# assert_arrays_equal(tractogram.data_per_streamline[key], +# data_per_streamline[key]) + +# for key in data_per_point.keys(): +# assert_arrays_equal(tractogram.data_per_point[key], +# data_per_point[key]) + + +def assert_tractogram_equal(t1, t2): + assert_true(isiterable(t1)) + assert_equal(len(t1), len(t2)) + assert_arrays_equal(t1.streamlines, t2.streamlines) + + assert_equal(len(t1.data_per_streamline), len(t2.data_per_streamline)) + for key in t1.data_per_streamline.keys(): + assert_arrays_equal(t1.data_per_streamline[key], + t2.data_per_streamline[key]) + + assert_equal(len(t1.data_per_point), len(t2.data_per_point)) + for key in t1.data_per_point.keys(): + assert_arrays_equal(t1.data_per_point[key], + t2.data_per_point[key]) + + class TestTractogramItem(unittest.TestCase): def test_creating_tractogram_item(self): diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 4abd5b688c..cbd0ffc785 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -8,6 +8,7 @@ from nibabel.testing import assert_arrays_equal, isiterable from nose.tools import assert_equal, assert_raises, assert_true +from .test_tractogram import assert_tractogram_equal from .. 
import base_format from ..tractogram import Tractogram, LazyTractogram from ..base_format import DataError, HeaderError, HeaderWarning#, UsageWarning @@ -31,36 +32,6 @@ def assert_header_equal(h1, h2): assert_equal(header1, header2) -def assert_tractogram_equal(t1, t2): - assert_equal(len(t1), len(t2)) - assert_arrays_equal(t1.streamlines, t2.streamlines) - - assert_equal(len(t1.data_per_streamline), len(t2.data_per_streamline)) - for key in t1.data_per_streamline.keys(): - assert_arrays_equal(t1.data_per_streamline[key], - t2.data_per_streamline[key]) - - assert_equal(len(t1.data_per_point), len(t2.data_per_point)) - for key in t1.data_per_point.keys(): - assert_arrays_equal(t1.data_per_point[key], - t2.data_per_point[key]) - - -def check_tractogram(tractogram, streamlines, data_per_streamline, data_per_point): - # Check data - assert_true(isiterable(tractogram)) - assert_equal(len(tractogram), len(streamlines)) - assert_arrays_equal(tractogram.streamlines, streamlines) - - for key in data_per_streamline.keys(): - assert_arrays_equal(tractogram.data_per_streamline[key], - data_per_streamline[key]) - - for key in data_per_point.keys(): - assert_arrays_equal(tractogram.data_per_point[key], - data_per_point[key]) - - class TestTRK(unittest.TestCase): def setUp(self): @@ -101,38 +72,32 @@ def setUp(self): np.array([0, 1, 0], dtype="f4"), np.array([0, 0, 1], dtype="f4")] - self.mean_curvature_torsion = [np.array([1.11, 1.22], dtype="f4"), - np.array([2.11, 2.22], dtype="f4"), - np.array([3.11, 3.22], dtype="f4")] - self.data_per_point = {'colors': self.colors, 'fa': self.fa} self.data_per_streamline = {'mean_curvature': self.mean_curvature, 'mean_torsion': self.mean_torsion, 'mean_colors': self.mean_colors} - self.nb_streamlines = len(self.streamlines) - self.nb_scalars_per_point = self.colors[0].shape[1] - self.nb_properties_per_streamline = len(self.mean_curvature_torsion[0]) - self.affine = np.eye(4) + self.empty_tractogram = Tractogram() + self.simple_tractogram = Tractogram(self.streamlines) + self.complex_tractogram = Tractogram(self.streamlines, + self.data_per_streamline, + self.data_per_point) def test_load_empty_file(self): for lazy_load in [False, True]: trk = TrkFile.load(self.empty_trk_filename, lazy_load=lazy_load) - check_tractogram(trk.tractogram, [], {}, {}) + assert_tractogram_equal(trk.tractogram, self.empty_tractogram) def test_load_simple_file(self): for lazy_load in [False, True]: trk = TrkFile.load(self.simple_trk_filename, lazy_load=lazy_load) - check_tractogram(trk.tractogram, self.streamlines, {}, {}) + assert_tractogram_equal(trk.tractogram, self.simple_tractogram) def test_load_complex_file(self): for lazy_load in [False, True]: trk = TrkFile.load(self.complex_trk_filename, lazy_load=lazy_load) - check_tractogram(trk.tractogram, - self.streamlines, - data_per_point=self.data_per_point, - data_per_streamline=self.data_per_streamline) + assert_tractogram_equal(trk.tractogram, self.complex_tractogram) def test_load_file_with_wrong_information(self): trk_file = open(self.simple_trk_filename, 'rb').read() @@ -141,7 +106,7 @@ def test_load_file_with_wrong_information(self): count = np.array(0, dtype="int32").tostring() new_trk_file = trk_file[:1000-12] + count + trk_file[1000-8:] trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) - check_tractogram(trk.tractogram, self.streamlines, {}, {}) + assert_tractogram_equal(trk.tractogram, self.simple_tractogram) # Simulate a TRK file where `voxel_order` was not provided. 
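The byte arithmetic in the test above follows from the TRK header having a fixed size of 1000 bytes: the streamline `count` is an int32 stored in the 4 bytes starting 12 bytes from the end, and the test later pokes `hdr_size` at offset 996 the same way. Spelled out::

    count = np.array(0, dtype="int32").tostring()  # 4 bytes
    # 1000 - 12 == 988 and 1000 - 8 == 992, so this overwrites bytes 988:992.
    new_trk_file = trk_file[:988] + count + trk_file[992:]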
voxel_order = np.zeros(1, dtype="|S3").tostring() @@ -162,23 +127,39 @@ def test_load_file_with_wrong_information(self): new_trk_file = trk_file[:996] + hdr_size + trk_file[996+4:] assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) + def test_write_empty_file(self): + tractogram = Tractogram() + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + loaded_trk = TrkFile.load(trk_file) + assert_tractogram_equal(loaded_trk.tractogram, tractogram) + + loaded_trk_orig = TrkFile.load(self.empty_trk_filename) + assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) + + trk_file.seek(0, os.SEEK_SET) + assert_equal(trk_file.read(), open(self.empty_trk_filename, 'rb').read()) + def test_write_simple_file(self): tractogram = Tractogram(self.streamlines) trk_file = BytesIO() - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) loaded_trk = TrkFile.load(trk_file) - check_tractogram(loaded_trk.tractogram, - self.streamlines, {}, {}) + assert_tractogram_equal(loaded_trk.tractogram, tractogram) loaded_trk_orig = TrkFile.load(self.simple_trk_filename) assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) - assert_equal(open(self.simple_trk_filename, 'rb').read(), trk_file.read()) + assert_equal(trk_file.read(), open(self.simple_trk_filename, 'rb').read()) def test_write_complex_file(self): # With scalars @@ -186,30 +167,24 @@ def test_write_complex_file(self): data_per_point=self.data_per_point) trk_file = BytesIO() - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) loaded_trk = TrkFile.load(trk_file, lazy_load=False) - check_tractogram(loaded_trk.tractogram, - self.streamlines, - data_per_streamline={}, - data_per_point=self.data_per_point) + assert_tractogram_equal(loaded_trk.tractogram, tractogram) # With properties tractogram = Tractogram(self.streamlines, data_per_streamline=self.data_per_streamline) - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) trk_file = BytesIO() trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) loaded_trk = TrkFile.load(trk_file, lazy_load=False) - check_tractogram(loaded_trk.tractogram, - self.streamlines, - data_per_streamline=self.data_per_streamline, - data_per_point={}) + assert_tractogram_equal(loaded_trk.tractogram, tractogram) # With scalars and properties tractogram = Tractogram(self.streamlines, @@ -217,21 +192,18 @@ def test_write_complex_file(self): data_per_streamline=self.data_per_streamline) trk_file = BytesIO() - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) loaded_trk = TrkFile.load(trk_file, lazy_load=False) - check_tractogram(loaded_trk.tractogram, - self.streamlines, - data_per_streamline=self.data_per_streamline, - data_per_point=self.data_per_point) + assert_tractogram_equal(loaded_trk.tractogram, tractogram) loaded_trk_orig = TrkFile.load(self.complex_trk_filename) assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) - assert_equal(open(self.complex_trk_filename, 'rb').read(), trk_file.read()) + assert_equal(trk_file.read(), open(self.complex_trk_filename, 'rb').read()) def test_write_erroneous_file(self): # No scalars for every points @@ -241,7 +213,7 @@ def test_write_erroneous_file(self): tractogram = Tractogram(self.streamlines, 
data_per_point={'scalars': scalars}) - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) assert_raises(DataError, trk.save, BytesIO()) # No scalars for every streamlines @@ -250,7 +222,7 @@ def test_write_erroneous_file(self): tractogram = Tractogram(self.streamlines, data_per_point={'scalars': scalars}) - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) assert_raises(IndexError, trk.save, BytesIO()) # Inconsistent number of properties @@ -259,7 +231,7 @@ def test_write_erroneous_file(self): np.array([3.11, 3.22], dtype="f4")] tractogram = Tractogram(self.streamlines, data_per_streamline={'properties': properties}) - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) assert_raises(DataError, trk.save, BytesIO()) # No properties for every streamlines @@ -267,7 +239,7 @@ def test_write_erroneous_file(self): np.array([2.11, 2.22], dtype="f4")] tractogram = Tractogram(self.streamlines, data_per_streamline={'properties': properties}) - trk = TrkFile(tractogram, ref=self.affine) + trk = TrkFile(tractogram) assert_raises(IndexError, trk.save, BytesIO()) def test_load_write_file(self): @@ -281,7 +253,7 @@ def test_load_write_file(self): assert_tractogram_equal(loaded_trk.tractogram, trk.tractogram) trk_file.seek(0, os.SEEK_SET) - #assert_equal(open(filename, 'rb').read(), trk_file.read()) + #assert_equal(trk_file.read(), open(filename, 'rb').read()) def test_load_write_LPS_file(self): trk = TrkFile.load(self.simple_LPS_trk_filename, lazy_load=False) @@ -300,4 +272,4 @@ def test_load_write_LPS_file(self): assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) - assert_equal(open(self.simple_LPS_trk_filename, 'rb').read(), trk_file.read()) + assert_equal(trk_file.read(), open(self.simple_LPS_trk_filename, 'rb').read()) From 7f917557950aa4b96d30f21728602faf26789adb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 24 Nov 2015 00:25:23 -0500 Subject: [PATCH 039/135] BF: limit property and scalar names to 18 characters --- nibabel/streamlines/tests/test_tractogram.py | 16 -- nibabel/streamlines/tests/test_trk.py | 134 +++++++++++--- nibabel/streamlines/trk.py | 181 ++++++------------- 3 files changed, 165 insertions(+), 166 deletions(-) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index ac458bae32..64b0e150ab 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -13,22 +13,6 @@ from ..tractogram import TractogramItem, Tractogram, LazyTractogram -# def check_tractogram(tractogram, streamlines, -# data_per_streamline, data_per_point): -# # Check data -# assert_true(isiterable(tractogram)) -# assert_equal(len(tractogram), len(streamlines)) -# assert_arrays_equal(tractogram.streamlines, streamlines) - -# for key in data_per_streamline.keys(): -# assert_arrays_equal(tractogram.data_per_streamline[key], -# data_per_streamline[key]) - -# for key in data_per_point.keys(): -# assert_arrays_equal(tractogram.data_per_point[key], -# data_per_point[key]) - - def assert_tractogram_equal(t1, t2): assert_true(isiterable(t1)) assert_equal(len(t1), len(t2)) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index cbd0ffc785..4b9252e00d 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -13,7 +13,7 @@ from ..tractogram import Tractogram, LazyTractogram from ..base_format 
import DataError, HeaderError, HeaderWarning#, UsageWarning -#from .. import trk +from .. import trk as trk_module from ..trk import TrkFile, header_2_dtype DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') @@ -111,7 +111,7 @@ def test_load_file_with_wrong_information(self): # Simulate a TRK file where `voxel_order` was not provided. voxel_order = np.zeros(1, dtype="|S3").tostring() new_trk_file = trk_file[:948] + voxel_order + trk_file[948+3:] - with clear_and_catch_warnings(record=True, modules=[trk]) as w: + with clear_and_catch_warnings(record=True, modules=[trk_module]) as w: TrkFile.load(BytesIO(new_trk_file)) assert_equal(len(w), 1) assert_true(issubclass(w[0].category, HeaderWarning)) @@ -135,11 +135,11 @@ def test_write_empty_file(self): trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_trk = TrkFile.load(trk_file) - assert_tractogram_equal(loaded_trk.tractogram, tractogram) + new_trk = TrkFile.load(trk_file) + assert_tractogram_equal(new_trk.tractogram, tractogram) - loaded_trk_orig = TrkFile.load(self.empty_trk_filename) - assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) + new_trk_orig = TrkFile.load(self.empty_trk_filename) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(trk_file.read(), open(self.empty_trk_filename, 'rb').read()) @@ -152,11 +152,11 @@ def test_write_simple_file(self): trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_trk = TrkFile.load(trk_file) - assert_tractogram_equal(loaded_trk.tractogram, tractogram) + new_trk = TrkFile.load(trk_file) + assert_tractogram_equal(new_trk.tractogram, tractogram) - loaded_trk_orig = TrkFile.load(self.simple_trk_filename) - assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) + new_trk_orig = TrkFile.load(self.simple_trk_filename) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(trk_file.read(), open(self.simple_trk_filename, 'rb').read()) @@ -171,8 +171,8 @@ def test_write_complex_file(self): trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_trk = TrkFile.load(trk_file, lazy_load=False) - assert_tractogram_equal(loaded_trk.tractogram, tractogram) + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) # With properties tractogram = Tractogram(self.streamlines, @@ -183,8 +183,8 @@ def test_write_complex_file(self): trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_trk = TrkFile.load(trk_file, lazy_load=False) - assert_tractogram_equal(loaded_trk.tractogram, tractogram) + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) # With scalars and properties tractogram = Tractogram(self.streamlines, @@ -196,11 +196,11 @@ def test_write_complex_file(self): trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_trk = TrkFile.load(trk_file, lazy_load=False) - assert_tractogram_equal(loaded_trk.tractogram, tractogram) + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) - loaded_trk_orig = TrkFile.load(self.complex_trk_filename) - assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) + new_trk_orig = TrkFile.load(self.complex_trk_filename) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(trk_file.read(), open(self.complex_trk_filename, 'rb').read()) @@ -249,8 +249,8 @@ 
def test_load_write_file(self): trk_file = BytesIO() trk.save(trk_file) - loaded_trk = TrkFile.load(filename, lazy_load=False) - assert_tractogram_equal(loaded_trk.tractogram, trk.tractogram) + new_trk = TrkFile.load(filename, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, trk.tractogram) trk_file.seek(0, os.SEEK_SET) #assert_equal(trk_file.read(), open(filename, 'rb').read()) @@ -263,13 +263,97 @@ def test_load_write_LPS_file(self): trk.save(trk_file) trk_file.seek(0, os.SEEK_SET) - loaded_trk = TrkFile.load(trk_file) + new_trk = TrkFile.load(trk_file) - assert_header_equal(loaded_trk.header, trk.header) - assert_tractogram_equal(loaded_trk.tractogram, trk.tractogram) + assert_header_equal(new_trk.header, trk.header) + assert_tractogram_equal(new_trk.tractogram, trk.tractogram) - loaded_trk_orig = TrkFile.load(self.simple_LPS_trk_filename) - assert_tractogram_equal(loaded_trk.tractogram, loaded_trk_orig.tractogram) + new_trk_orig = TrkFile.load(self.simple_LPS_trk_filename) + assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(trk_file.read(), open(self.simple_LPS_trk_filename, 'rb').read()) + + def test_write_too_many_scalars_and_properties(self): + # TRK supports up to 10 data_per_point. + data_per_point = {} + for i in range(10): + data_per_point['#{0}'.format(i)] = self.fa + + tractogram = Tractogram(self.streamlines, + data_per_point=data_per_point) + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + # More than 10 data_per_point should raise an error. + data_per_point['#{0}'.format(i+1)] = self.fa + + tractogram = Tractogram(self.streamlines, + data_per_point=data_per_point) + + trk = TrkFile(tractogram) + assert_raises(ValueError, trk.save, BytesIO()) + + # TRK supports up to 10 data_per_streamline. + data_per_streamline = {} + for i in range(10): + data_per_streamline['#{0}'.format(i)] = self.mean_torsion + + tractogram = Tractogram(self.streamlines, + data_per_streamline=data_per_streamline) + + trk_file = BytesIO() + trk = TrkFile(tractogram) + trk.save(trk_file) + trk_file.seek(0, os.SEEK_SET) + + new_trk = TrkFile.load(trk_file, lazy_load=False) + assert_tractogram_equal(new_trk.tractogram, tractogram) + + # More than 10 data_per_streamline should raise an error. + data_per_streamline['#{0}'.format(i+1)] = self.mean_torsion + + tractogram = Tractogram(self.streamlines, + data_per_streamline=data_per_streamline) + + trk = TrkFile(tractogram) + assert_raises(ValueError, trk.save, BytesIO()) + + def test_write_scalars_and_properties_name_too_long(self): + # TRK supports data_per_point name up to 20 characters. + # However, we reserve the last two characters to store + # the number of values associated to each data_per_point. + # So in reality we allow name of 18 characters, otherwise + # the name is truncated and warning is issue. + for nb_chars in range(22): + data_per_point = {'A'*nb_chars: self.fa} + tractogram = Tractogram(self.streamlines, + data_per_point=data_per_point) + + trk = TrkFile(tractogram) + if nb_chars > 18: + assert_raises(ValueError, trk.save, BytesIO()) + else: + trk.save(BytesIO()) + + # TRK supports data_per_streamline name up to 20 characters. + # However, we reserve the last two characters to store + # the number of values associated to each data_per_streamline. 
+ # So in reality we allow name of 18 characters, otherwise + # the name is truncated and warning is issue. + for nb_chars in range(22): + data_per_streamline = {'A'*nb_chars: self.mean_torsion} + tractogram = Tractogram(self.streamlines, + data_per_streamline=data_per_streamline) + + trk = TrkFile(tractogram) + if nb_chars > 18: + assert_raises(ValueError, trk.save, BytesIO()) + else: + trk.save(BytesIO()) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 3570afa544..5caa098abf 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -21,7 +21,6 @@ from .tractogram import TractogramItem, Tractogram, LazyTractogram from .header import Field -from .utils import get_affine_from_reference MAX_NB_NAMED_SCALARS_PER_POINT = 10 MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE = 10 @@ -217,45 +216,6 @@ def create_empty_header(cls): return header - # def __init__(self, fileobj, header): - # self.header = self.create_empty_header() - - # # Override hdr's fields by those contained in `header`. - # for k, v in header.extra.items(): - # if k in header_2_dtype.fields.keys(): - # self.header[k] = v - - # # TODO: Fix that ugly patch. - # # Because the assignment operator on ndarray of string only copy the - # # first entry, we have to do it explicitly! - # if "property_name" in header.extra: - # self.header["property_name"][:] = header.extra["property_name"][:] - - # if "scalar_name" in header.extra: - # self.header["scalar_name"][:] = header.extra["scalar_name"][:] - - # self.header[Field.NB_STREAMLINES] = 0 - # if header.nb_streamlines is not None: - # self.header[Field.NB_STREAMLINES] = header.nb_streamlines - - # self.header[Field.NB_SCALARS_PER_POINT] = header.nb_scalars_per_point - # self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = header.nb_properties_per_streamline - # self.header[Field.VOXEL_SIZES] = header.voxel_sizes - # self.header[Field.VOXEL_TO_RASMM] = header.to_world_space - # self.header[Field.VOXEL_ORDER] = header.voxel_order - - # # Keep counts for correcting incoherent fields or warn. - # self.nb_streamlines = 0 - # self.nb_points = 0 - # self.nb_scalars = 0 - # self.nb_properties = 0 - - # # Write header - # self.file = Opener(fileobj, mode="wb") - # # Keep track of the beginning of the header. - # self.beginning = self.file.tell() - # self.file.write(self.header.tostring()) - def __init__(self, fileobj, header): self.header = self.create_empty_header() @@ -264,15 +224,6 @@ def __init__(self, fileobj, header): if k in header_2_dtype.fields.keys(): self.header[k] = v - # self.header[Field.NB_STREAMLINES] = 0 - # if header.nb_streamlines is not None: - # self.header[Field.NB_STREAMLINES] = header.nb_streamlines - - # self.header[Field.NB_SCALARS_PER_POINT] = header.nb_scalars_per_point - # self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = header.nb_properties_per_streamline - # self.header[Field.VOXEL_SIZES] = header.voxel_sizes - # self.header[Field.VOXEL_TO_RASMM] = header.to_world_space - # By default, the voxel order is LPS. # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates if self.header[Field.VOXEL_ORDER] == b"": @@ -309,35 +260,34 @@ def write(self, tractogram): # Update the 'property_name' field using 'data_per_streamline' of the tractogram. 
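Both name fields are arrays of 20-byte strings (`dtype='S20'`). The scheme below packs the text into the first 18 bytes, keeps byte 19 as a NUL separator, and stores the number of values in the final byte; a round-trip sketch, assuming a name of at most 18 characters and Python 2 byte strings as in this patch::

    import numpy as np

    name, nb_values = 'mean_curvature', 1
    packed = (name[:18].ljust(18, '\x00') + '\x00' +
              np.array(nb_values, dtype=np.int8).tostring())
    assert len(packed) == 20

    decoded_name = packed.split('\x00')[0]                   # 'mean_curvature'
    decoded_count = int(np.fromstring(packed[-1], np.int8))  # 1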
data_for_streamline = first_item.data_for_streamline - data_for_streamline_keys = sorted(data_for_streamline.keys())[:MAX_NB_NAMED_SCALARS_PER_POINT] + if len(data_for_streamline) > MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE: + raise ValueError("Can only store {0} named data_per_streamline (properties).".format(MAX_NB_NAMED_SCALARS_PER_POINT)) + + data_for_streamline_keys = sorted(data_for_streamline.keys()) self.header['property_name'] = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, dtype='S20') for i, k in enumerate(data_for_streamline_keys): - if i >= MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE: - warnings.warn(("Can only store {0} named properties: '{1}' will be omitted.".format(MAX_NB_NAMED_SCALARS_PER_POINT, k)), HeaderWarning) - - if len(k) > 19: - warnings.warn(("Property name '{0}' has be truncated to {1}.".format(k, k[:19])), HeaderWarning) + if len(k) > 18: + raise ValueError("Property name '{0}' too long (max 18 char.)".format(k)) v = data_for_streamline[k] - property_name = k[:19].ljust(19, '\x00') + np.array(v.shape[0], dtype=np.int8).tostring() + property_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(v.shape[0], dtype=np.int8).tostring() self.header['property_name'][i] = property_name # Update the 'scalar_name' field using 'data_per_point' of the tractogram. data_for_points = first_item.data_for_points - data_for_points_keys = sorted(data_for_points.keys())[:MAX_NB_NAMED_SCALARS_PER_POINT] + if len(data_for_points) > MAX_NB_NAMED_SCALARS_PER_POINT: + raise ValueError("Can only store {0} named data_per_point (scalars).".format(MAX_NB_NAMED_SCALARS_PER_POINT)) + + data_for_points_keys = sorted(data_for_points.keys()) self.header['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20') for i, k in enumerate(data_for_points_keys): - if i >= MAX_NB_NAMED_SCALARS_PER_POINT: - warnings.warn(("Can only store {0} named scalars: '{1}' will be omitted.".format(MAX_NB_NAMED_SCALARS_PER_POINT, k)), HeaderWarning) - - if len(k) > 19: - warnings.warn(("Scalar name '{0}' has be truncated to {1}.".format(k, k[:19])), HeaderWarning) + if len(k) > 18: + raise ValueError("Scalar name '{0}' too long (max 18 char.)".format(k)) v = data_for_points[k] - scalar_name = k[:19].ljust(19, '\x00') + np.array(v.shape[1], dtype=np.int8).tostring() + scalar_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(v.shape[1], dtype=np.int8).tostring() self.header['scalar_name'][i] = scalar_name - # `Tractogram` streamlines are in RAS+ and mm space, we will compute # the affine matrix that will bring them back to 'voxelmm' as required # by the TRK format. @@ -638,39 +588,48 @@ def load(cls, fileobj, lazy_load=False): offset[:-1, -1] -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2. affine = np.dot(offset, affine) - if lazy_load: - #pts, scals, props = next(iter(trk_reader)) - data_per_point_slice = {} - if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0: - cpt = 0 - for scalar_name in trk_reader.header['scalar_name']: - if len(scalar_name) == 0: - continue + # Find scalars and properties name + data_per_point_slice = {} + if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0: + cpt = 0 + for scalar_name in trk_reader.header['scalar_name']: + if len(scalar_name) == 0: + continue + # Check if we encoded the number of values we stocked for this scalar name. 
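The two-byte test that follows works because numpy strips trailing NUL padding when reading an `S20` item: a plain name comes back as bare text, so its last two characters are ordinary letters, whereas a packed name keeps its non-NUL count byte and therefore still shows the NUL separator at position -2. A sketch with a hypothetical `split_name` helper, again assuming Python 2 byte strings::

    def split_name(field):
        # Packed layout: <name, NUL-padded to 18 bytes><NUL><int8 count>.
        if len(field) > 1 and field[-2] == '\x00' and field[-1] != '\x00':
            nb_values = int(np.fromstring(field[-1], np.int8))
            return field.split('\x00')[0], nb_values
        return field, 1  # plain layout: padding already stripped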
+ nb_scalars = 1 + if scalar_name[-2] == '\x00' and scalar_name[-1] != '\x00': nb_scalars = int(np.fromstring(scalar_name[-1], np.int8)) - scalar_name = scalar_name.split('\x00')[0] - data_per_point_slice[scalar_name] = slice(cpt, cpt+nb_scalars) - cpt += nb_scalars - if cpt < trk_reader.header[Field.NB_SCALARS_PER_POINT]: - data_per_point_slice['scalars'] = slice(cpt, trk_reader.header[Field.NB_SCALARS_PER_POINT]) + scalar_name = scalar_name.split('\x00')[0] + data_per_point_slice[scalar_name] = slice(cpt, cpt+nb_scalars) + cpt += nb_scalars - data_per_streamline_slice = {} - if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: - cpt = 0 - for property_name in trk_reader.header['property_name']: - if len(property_name) == 0: - continue + if cpt < trk_reader.header[Field.NB_SCALARS_PER_POINT]: + data_per_point_slice['scalars'] = slice(cpt, trk_reader.header[Field.NB_SCALARS_PER_POINT]) + data_per_streamline_slice = {} + if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: + cpt = 0 + for property_name in trk_reader.header['property_name']: + if len(property_name) == 0: + continue + + # Check if we encoded the number of values we stocked for this property name. + nb_properties = 1 + if property_name[-2] == '\x00' and property_name[-1] != '\x00': nb_properties = int(np.fromstring(property_name[-1], np.int8)) - property_name = property_name.split('\x00')[0] - data_per_streamline_slice[property_name] = slice(cpt, cpt+nb_properties) - cpt += nb_properties - if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: - data_per_streamline_slice['properties'] = slice(cpt, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]) + property_name = property_name.split('\x00')[0] + data_per_streamline_slice[property_name] = slice(cpt, cpt+nb_properties) + cpt += nb_properties + + if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: + data_per_streamline_slice['properties'] = slice(cpt, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]) + + if lazy_load: def _read(): for pts, scals, props in trk_reader: data_for_points = {k: scals[:, v] for k, v in data_per_point_slice.items()} @@ -683,43 +642,15 @@ def _read(): streamlines, scalars, properties = create_compactlist_from_generator(trk_reader) tractogram = Tractogram(streamlines) - if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0: - cpt = 0 - for scalar_name in trk_reader.header['scalar_name']: - if len(scalar_name) == 0: - continue - - nb_scalars = np.fromstring(scalar_name[-1], np.int8) - - clist = CompactList() - clist._data = scalars._data[:, cpt:cpt+nb_scalars] - clist._offsets = scalars._offsets - clist._lengths = scalars._lengths - - scalar_name = scalar_name.split('\x00')[0] - tractogram.data_per_point[scalar_name] = clist - cpt += nb_scalars - - if cpt < trk_reader.header[Field.NB_SCALARS_PER_POINT]: - clist = CompactList() - clist._data = scalars._data[:, cpt:] - clist._offsets = scalars._offsets - clist._lengths = scalars._lengths - tractogram.data_per_point['scalars'] = clist - - if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: - cpt = 0 - for property_name in trk_reader.header['property_name']: - if len(property_name) == 0: - continue - - nb_properties = np.fromstring(property_name[-1], np.int8) - property_name = property_name.split('\x00')[0] - tractogram.data_per_streamline[property_name] = properties[:, cpt:cpt+nb_properties] - cpt += nb_properties - - if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: - tractogram.data_per_streamline['properties'] = properties[:, cpt:] + 
for scalar_name, slice_ in data_per_point_slice.items(): + clist = CompactList() + clist._data = scalars._data[:, slice_] + clist._offsets = scalars._offsets + clist._lengths = scalars._lengths + tractogram.data_per_point[scalar_name] = clist + + for property_name, slice_ in data_per_streamline_slice.items(): + tractogram.data_per_streamline[property_name] = properties[:, slice_] # Bring tractogram to RAS+ and mm space tractogram.apply_affine(affine.astype(np.float32)) From 4d93dde1ad028d7dfeb84f0c4837b58632d5fc48 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 24 Nov 2015 10:48:14 -0500 Subject: [PATCH 040/135] BF: Only store the nb of values in the property or scalar name if it is greater than one --- nibabel/streamlines/tests/data/complex.trk | Bin 1296 -> 1296 bytes nibabel/streamlines/tests/test_trk.py | 27 ++++++++++++++++--- nibabel/streamlines/trk.py | 30 +++++++++++++++------ 3 files changed, 46 insertions(+), 11 deletions(-) diff --git a/nibabel/streamlines/tests/data/complex.trk b/nibabel/streamlines/tests/data/complex.trk index 8aa099becd3bf9436d6cd0cc4ded087c3e794f8b..0a874ea6e74a7646c163b437f5423443a4816105 100644 GIT binary patch delta 51 zcmbQhHGyk_?_>d{g%jUt@-pP6Cg#PL 20: + assert_raises(ValueError, trk.save, BytesIO()) + else: + trk.save(BytesIO()) + # TRK supports data_per_streamline name up to 20 characters. # However, we reserve the last two characters to store # the number of values associated to each data_per_streamline. # So in reality we allow name of 18 characters, otherwise # the name is truncated and warning is issue. for nb_chars in range(22): - data_per_streamline = {'A'*nb_chars: self.mean_torsion} + data_per_streamline = {'A'*nb_chars: self.mean_colors} tractogram = Tractogram(self.streamlines, data_per_streamline=data_per_streamline) @@ -357,3 +368,13 @@ def test_write_scalars_and_properties_name_too_long(self): assert_raises(ValueError, trk.save, BytesIO()) else: trk.save(BytesIO()) + + data_per_streamline = {'A'*nb_chars: self.mean_torsion} + tractogram = Tractogram(self.streamlines, + data_per_streamline=data_per_streamline) + + trk = TrkFile(tractogram) + if nb_chars > 20: + assert_raises(ValueError, trk.save, BytesIO()) + else: + trk.save(BytesIO()) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 5caa098abf..3922f41e99 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -266,11 +266,18 @@ def write(self, tractogram): data_for_streamline_keys = sorted(data_for_streamline.keys()) self.header['property_name'] = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, dtype='S20') for i, k in enumerate(data_for_streamline_keys): - if len(k) > 18: - raise ValueError("Property name '{0}' too long (max 18 char.)".format(k)) + nb_values = data_for_streamline[k].shape[0] + + if len(k) > 20: + raise ValueError("Property name '{0}' is too long (max 20 char.)".format(k)) + elif len(k) > 18 and nb_values > 1: + raise ValueError("Property name '{0}' is too long (need to be less than 18 characters when storing more than one value".format(k)) + + property_name = k + if nb_values > 1: + # Use the last to bytes of the name to store the nb of values associated to this data_for_streamline. 
+ property_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(nb_values, dtype=np.int8).tostring() - v = data_for_streamline[k] - property_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(v.shape[0], dtype=np.int8).tostring() self.header['property_name'][i] = property_name # Update the 'scalar_name' field using 'data_per_point' of the tractogram. @@ -281,11 +288,18 @@ def write(self, tractogram): data_for_points_keys = sorted(data_for_points.keys()) self.header['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20') for i, k in enumerate(data_for_points_keys): - if len(k) > 18: - raise ValueError("Scalar name '{0}' too long (max 18 char.)".format(k)) + nb_values = data_for_points[k].shape[1] + + if len(k) > 20: + raise ValueError("Scalar name '{0}' is too long (max 18 char.)".format(k)) + elif len(k) > 18 and nb_values > 1: + raise ValueError("Scalar name '{0}' is too long (need to be less than 18 characters when storing more than one value".format(k)) + + scalar_name = k + if nb_values > 1: + # Use the last to bytes of the name to store the nb of values associated to this data_for_streamline. + scalar_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(nb_values, dtype=np.int8).tostring() - v = data_for_points[k] - scalar_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(v.shape[1], dtype=np.int8).tostring() self.header['scalar_name'][i] = scalar_name # `Tractogram` streamlines are in RAS+ and mm space, we will compute From 5d51c904f847647b55441b3a6c8db1c9d71de99d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 29 Nov 2015 01:02:28 -0500 Subject: [PATCH 041/135] Added a script to generate standard test object. --- .../streamlines/tests/data/gen_standard.py | 79 ++++++++++++++++++ .../streamlines/tests/data/standard.LPS.trk | Bin 0 -> 5800 bytes nibabel/streamlines/tests/data/standard.trk | Bin 0 -> 5800 bytes nibabel/streamlines/tests/test_streamlines.py | 5 -- nibabel/streamlines/tests/test_trk.py | 20 +++-- nibabel/streamlines/trk.py | 29 ++++--- 6 files changed, 107 insertions(+), 26 deletions(-) create mode 100644 nibabel/streamlines/tests/data/gen_standard.py create mode 100644 nibabel/streamlines/tests/data/standard.LPS.trk create mode 100644 nibabel/streamlines/tests/data/standard.trk diff --git a/nibabel/streamlines/tests/data/gen_standard.py b/nibabel/streamlines/tests/data/gen_standard.py new file mode 100644 index 0000000000..63b4173602 --- /dev/null +++ b/nibabel/streamlines/tests/data/gen_standard.py @@ -0,0 +1,79 @@ +import numpy as np +import nibabel as nib + +from nibabel.streamlines import FORMATS +from nibabel.streamlines.header import Field + + +def mark_the_spot(mask): + """ Marks every nonzero voxel using streamlines to form a 3D 'X' inside. + + Generates streamlines forming a 3D 'X' inside every nonzero voxel. + + Parameters + ---------- + mask : ndarray + Mask containing the spots to be marked. + + Returns + ------- + list of ndarrays + All streamlines needed to mark every nonzero voxel in the `mask`. + """ + def _gen_straight_streamline(start, end, steps=3): + coords = [] + for s, e in zip(start, end): + coords.append(np.linspace(s, e, steps)) + + return np.array(coords).T + + # Generate a 3D 'X' template fitting inside the voxel centered at (0,0,0). 
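With `steps=3`, each template line is just its two endpoints plus the midpoint; for the first diagonal::

    >>> _gen_straight_streamline((-0.5, -0.5, -0.5), (0.5, 0.5, 0.5))
    array([[-0.5, -0.5, -0.5],
           [ 0. ,  0. ,  0. ],
           [ 0.5,  0.5,  0.5]])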
+ X = [] + X.append(_gen_straight_streamline((-0.5, -0.5, -0.5), (0.5, 0.5, 0.5))) + X.append(_gen_straight_streamline((-0.5, 0.5, -0.5), (0.5, -0.5, 0.5))) + X.append(_gen_straight_streamline((-0.5, 0.5, 0.5), (0.5, -0.5, -0.5))) + X.append(_gen_straight_streamline((-0.5, -0.5, 0.5), (0.5, 0.5, -0.5))) + + # Get the coordinates of voxels 'on' in the mask. + coords = np.array(zip(*np.where(mask))) + + streamlines = [] + for c in coords: + for line in X: + streamlines.append((line + c) * voxel_size) + + return streamlines + + +rng = np.random.RandomState(42) + +width = 4 # Coronal +height = 5 # Sagittal +depth = 7 # Axial + +voxel_size = np.array((1., 3., 2.)) + +# Generate a random mask with voxel order RAS+. +mask = rng.rand(width, height, depth) > 0.8 +mask = (255*mask).astype(np.uint8) + +# Build tractogram +streamlines = mark_the_spot(mask) +tractogram = nib.streamlines.Tractogram(streamlines) + +# Build header +affine = np.eye(4) +affine[range(3), range(3)] = voxel_size +header = {Field.DIMENSIONS: (width, height, depth), + Field.VOXEL_SIZES: voxel_size, + Field.VOXEL_TO_RASMM: affine, + Field.VOXEL_ORDER: 'RAS'} + +# Save the standard mask. +nii = nib.Nifti1Image(mask, affine=affine) +nib.save(nii, "standard.nii.gz") + +# Save the standard tractogram in every available file format. +for ext, cls in FORMATS.items(): + tfile = cls(tractogram, header) + nib.streamlines.save(tfile, "standard" + ext) diff --git a/nibabel/streamlines/tests/data/standard.LPS.trk b/nibabel/streamlines/tests/data/standard.LPS.trk new file mode 100644 index 0000000000000000000000000000000000000000..ebda71bdb80c27dec2a8e61f22585ca9f56fff41 GIT binary patch literal 5800 zcmeH}Ka$lv5X7hWQ>1$YhlJb!y&{C*6%Y{m2pbam2zvyM;-hRx$Pwt)Xjc8k{Y`g|?C1Kvf4zMAY;WwXy)(1zqgiX1@Ol~l^a=1;<rB*FQnL!u9*L8UDlAn%VsY6Wmu^KQM2&_Y3o_O|G4Atk2w5=(*3- z^SavCvadP!o?hddbl3Aty}m9u&z-oDPxU%ylhb~cbDzW^3(t=n|IYK{eYs4o*A*On zGPj&?CVF01`##5gfOh0-qUU+1UU2lu9QUGM`f^=y ztj}EJQ$6>&^{qJfIp?n9{K&1+FW0IU9DOslO25i!A71C@$3B8{AJ0eT9$Fkr^WQIj zABH)z#P5h)Ut=qCy}nI7uRGyva$a}j+-GvC$KMCDFy9^K@!jAa*W6!lDaZQEZ6{pr zx8mGq=2PE4&~3O3*N#hlx#s?YOS#;K*VUZn^SUGFKDYU}7ru{}8!qnIc>mg1b8o@L ze2=&cSMM)zhBK|WUETw)J90ZTpW}L6XHff$Z*aV4ZjpS-X}?}K^99Fznd3VY{a!J* zk#9};Dvr9DTMIpMP4v93_EnB~!}kI6$k#AeU#=^T^_h!&s^>oQz4AKGr=0stj%yqJ z@;k)7jC&Z1^xHQIj`f*aPB;@iuk(D_H*;|=@tur(tLQOb#nC5otK`$?(L~SdX1?Iu z$Ip*B?xD5g&TRbut4GW?%$X(nAjfsZu|9J~HtD&~t#8G-&*apX-`&V}mwdU7e9Ez2 z^Ld@;%UpanqTk4;ocqjts>kmd*E~OFS;q6A9@n9#9Qn*G>;)Q`%Y9Uw`^j?W`c?WMw-$QtbL(4itZze)=S0swFy}dQkE5S!b3N|$HUE&fUZ1&k z!r9dGx+CX4=N$XwAh$i^8ytNyw>1ciK9n=j^SUFK&xPDBdXJoAbWYBh>x%2=kT?Sy z`9jZq(zz6z`BH-0AEDRpi6a~RUOC4gw+hE~#c?ll zD>cUR4Lx#A^t|rKxz9X5uB|!8??K}ngU?dmE3PY!^_knqBA@EH&#iC8xz9QGi`p3E zc6pz<-f)F)aP-X_*^1LXyza=k&&=1gfKl(EAp`Clt?{49^)>&HxL%(*)>fQNJ+C`* z?sLvX4&JrjzAwJP?FF};aN4ie*;5~(*Y79&?&vuz%nbUk; nchqBj=J;C*-z(>1=gT);^k*f&<3e_P@RYyjOYa_iifYg5T%4`d0Zp zuE68%^Dhnh71TS#_xEP_)2$%E^RplCd%X9p4DRJSkT}X@t}XP)ndmWHaqg3R3B&hW z@n@Cu?KoDwo}-kxtml0foLm?0n>qJM^W6sbfa92NyA7`QEjapQZoA-2^t^BF`x^JR z%vVH@^Qmue^vPV2`IOs4&-=!F!MV@4F4PL&e{me#6OPrF`HEwI=2+u%pn9yG=Iebk z=RW88RPT;H-d}Q6*2{dwv6i`tTFe)EK+ zZbOahP(5-cdfqp4?sJ}xZ0OyobIeyok9{kSwanFp9=RfV-naG*j&pF^3$6izGd^Qq zj<7?|j(o>{Lnvq55@(oY#qGtf3-$W8;0n$u=eg_gyZelYn)>d&Z|2BmF6L7`*BtYC z->m08gJT|g{O&&RJ{{*$LHf53~sf^|Iawy5%GIOrspVSj^`#0wIyfCcfq;OHAlaO zAlJ&^dY|Y8=le6)E;wCB^jENaP9xuOoH^Eld&Y4?ux7{oC`Zm(MZ|O8aLjX=Gv4zB zx669HZ{~LNy~aI2yLDzf_tava`Ub~yTUy5^^C_q6@V+r$aPBki2f6;NV$bVxm-%Y) zU2*KsT$TAk&wb9mD~|J(Uw*=;CjFKdf%+) zKFJq8T*u^V*IWS~^XYoCZ*Y0N voxel affine = 
np.dot(np.linalg.inv(self.header[Field.VOXEL_TO_RASMM]), affine) @@ -330,6 +322,13 @@ def write(self, tractogram): M = nib.orientations.inv_ornt_aff(ornt, self.header[Field.DIMENSIONS]) affine = np.dot(M, affine) + # TrackVis considers coordinate (0,0,0) to be the corner of the + # voxel whereas `Tractogram` streamlines assume (0,0,0) is the + # center of the voxel. Thus, streamlines are shifted of half a voxel. + offset = np.eye(4) + offset[:-1, -1] += 0.5 + affine = np.dot(offset, affine) + # Finally send the streamlines in mm space. # voxel -> voxelmm scale = np.eye(4) @@ -580,6 +579,13 @@ def load(cls, fileobj, lazy_load=False): scale[range(3), range(3)] /= trk_reader.header[Field.VOXEL_SIZES] affine = np.dot(scale, affine) + # TrackVis considers coordinate (0,0,0) to be the corner of the voxel + # whereas streamlines returned assume (0,0,0) to be the center of the + # voxel. Thus, streamlines are shifted of half a voxel. + offset = np.eye(4) + offset[:-1, -1] -= 0.5 + affine = np.dot(offset, affine) + # If the voxel order implied by the affine does not match the voxel # order in the TRK header, change the orientation. # voxel (header) -> voxel (affine) @@ -595,13 +601,6 @@ def load(cls, fileobj, lazy_load=False): # voxel -> rasmm affine = np.dot(trk_reader.header[Field.VOXEL_TO_RASMM], affine) - # TrackVis considers coordinate (0,0,0) to be the corner of the voxel - # whereas streamlines returned assume (0,0,0) to be the center of the - # voxel. Thus, streamlines are shifted of half a voxel. - offset = np.eye(4) - offset[:-1, -1] -= np.array(trk_reader.header[Field.VOXEL_SIZES])/2. - affine = np.dot(offset, affine) - # Find scalars and properties name data_per_point_slice = {} From 9670c43fba9b9c34bab4ec1f2a3c0dbdeb9f2b6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 29 Nov 2015 01:04:25 -0500 Subject: [PATCH 042/135] Added the mask used by the standard test object --- nibabel/streamlines/tests/data/standard.nii.gz | Bin 0 -> 143 bytes 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 nibabel/streamlines/tests/data/standard.nii.gz diff --git a/nibabel/streamlines/tests/data/standard.nii.gz b/nibabel/streamlines/tests/data/standard.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..98bb31a77860a668b865d16f4605b8612eba512c GIT binary patch literal 143 zcmV;A0C4{wiwFoOlv-8-|8sO Date: Sun, 29 Nov 2015 02:05:34 -0500 Subject: [PATCH 043/135] Added support for Python 3 --- nibabel/streamlines/compact_list.py | 2 +- nibabel/streamlines/tests/test_compact_list.py | 12 ++++++------ nibabel/streamlines/trk.py | 12 +++++++----- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index fb98536bab..785e363490 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -110,7 +110,7 @@ def extend(self, elements): lengths = elements._lengths else: self._data = np.concatenate([self._data] + list(elements), axis=0) - lengths = map(len, elements) + lengths = list(map(len, elements)) idx = self._offsets[-1] + self._lengths[-1] if len(self) > 0 else 0 self._lengths.extend(lengths) diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py index 71dac112b7..34b1ceb65d 100644 --- a/nibabel/streamlines/tests/test_compact_list.py +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -17,7 +17,7 @@ class TestCompactList(unittest.TestCase): def setUp(self): rng 
= np.random.RandomState(42) self.data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - self.lengths = map(len, self.data) + self.lengths = list(map(len, self.data)) self.clist = CompactList(self.data) def test_creating_empty_compactlist(self): @@ -31,7 +31,7 @@ def test_creating_empty_compactlist(self): def test_creating_compactlist_from_list(self): rng = np.random.RandomState(42) data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = map(len, data) + lengths = list(map(len, data)) clist = CompactList(data) assert_equal(len(clist), len(data)) @@ -54,7 +54,7 @@ def test_creating_compactlist_from_list(self): def test_creating_compactlist_from_generator(self): rng = np.random.RandomState(42) data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = map(len, data) + lengths = list(map(len, data)) gen = (e for e in data) clist = CompactList(gen) @@ -78,7 +78,7 @@ def test_creating_compactlist_from_generator(self): def test_creating_compactlist_from_compact_list(self): rng = np.random.RandomState(42) data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = map(len, data) + lengths = list(map(len, data)) clist = CompactList(data) clist2 = CompactList(clist) @@ -152,7 +152,7 @@ def test_compactlist_extend(self): rng = np.random.RandomState(1234) shape = self.clist.shape new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(5)] - lengths = map(len, new_data) + lengths = list(map(len, new_data)) clist.extend(new_data) assert_equal(len(clist), len(self.clist)+len(new_data)) assert_array_equal(clist._offsets[-len(new_data):], @@ -187,7 +187,7 @@ def test_compactlist_getitem(self): assert_array_equal(self.clist[i], e) # Get multiple items (this will create a view). - clist_view = self.clist[range(len(self.clist))] + clist_view = self.clist[list(range(len(self.clist)))] assert_true(clist_view is not self.clist) assert_true(clist_view._data is self.clist._data) assert_true(clist_view._offsets is not self.clist._offsets) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index cdff6d6624..251d7681c8 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -13,6 +13,7 @@ from nibabel.affines import apply_affine from nibabel.openers import Opener +from nibabel.py3k import asbytes, asstr from nibabel.volumeutils import (native_code, swapped_code) from .compact_list import CompactList @@ -276,7 +277,7 @@ def write(self, tractogram): property_name = k if nb_values > 1: # Use the last to bytes of the name to store the nb of values associated to this data_for_streamline. - property_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(nb_values, dtype=np.int8).tostring() + property_name = asbytes(k[:18].ljust(18, '\x00')) + b'\x00' + np.array(nb_values, dtype=np.int8).tostring() self.header['property_name'][i] = property_name @@ -298,7 +299,7 @@ def write(self, tractogram): scalar_name = k if nb_values > 1: # Use the last to bytes of the name to store the nb of values associated to this data_for_streamline. - scalar_name = k[:18].ljust(18, '\x00') + '\x00' + np.array(nb_values, dtype=np.int8).tostring() + scalar_name = asbytes(k[:18].ljust(18, '\x00')) + b'\x00' + np.array(nb_values, dtype=np.int8).tostring() self.header['scalar_name'][i] = scalar_name @@ -314,7 +315,7 @@ def write(self, tractogram): # If the voxel order implied by the affine does not match the voxel # order in the TRK header, change the orientation. 
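For context on the `asstr` conversions below: the voxel-order comparison goes through nibabel's orientation helpers, which expect native strings rather than bytes under Python 3. A sketch of the chain for a header in LPS and an affine in RAS, using the (4, 5, 7) dimensions of the standard test data as an example::

    import nibabel as nib

    header_ornt = nib.orientations.axcodes2ornt('LPS')
    affine_ornt = nib.orientations.axcodes2ornt('RAS')
    ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt)
    M = nib.orientations.inv_ornt_aff(ornt, (4, 5, 7))  # voxel-order swap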
# voxel (affine) -> voxel (header) - header_ornt = self.header[Field.VOXEL_ORDER] + header_ornt = asstr(self.header[Field.VOXEL_ORDER]) affine_ornt = "".join(nib.orientations.aff2axcodes(self.header[Field.VOXEL_TO_RASMM])) header_ornt = nib.orientations.axcodes2ornt(header_ornt) affine_ornt = nib.orientations.axcodes2ornt(affine_ornt) @@ -589,7 +590,7 @@ def load(cls, fileobj, lazy_load=False): # If the voxel order implied by the affine does not match the voxel # order in the TRK header, change the orientation. # voxel (header) -> voxel (affine) - header_ornt = trk_reader.header[Field.VOXEL_ORDER] + header_ornt = asstr(trk_reader.header[Field.VOXEL_ORDER]) affine_ornt = "".join(nib.orientations.aff2axcodes(trk_reader.header[Field.VOXEL_TO_RASMM])) header_ornt = nib.orientations.axcodes2ornt(header_ornt) affine_ornt = nib.orientations.axcodes2ornt(affine_ornt) @@ -601,12 +602,12 @@ def load(cls, fileobj, lazy_load=False): # voxel -> rasmm affine = np.dot(trk_reader.header[Field.VOXEL_TO_RASMM], affine) - # Find scalars and properties name data_per_point_slice = {} if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0: cpt = 0 for scalar_name in trk_reader.header['scalar_name']: + scalar_name = asstr(scalar_name) if len(scalar_name) == 0: continue @@ -626,6 +627,7 @@ def load(cls, fileobj, lazy_load=False): if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: cpt = 0 for property_name in trk_reader.header['property_name']: + property_name = asstr(property_name) if len(property_name) == 0: continue From 1fa32e19512195c9bdafac76c66bd8cabd23ba4a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 29 Nov 2015 02:25:50 -0500 Subject: [PATCH 044/135] Revert changes made to __init__.py --- nibabel/__init__.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/nibabel/__init__.py b/nibabel/__init__.py index 163bcd9e15..4d8791d7d9 100644 --- a/nibabel/__init__.py +++ b/nibabel/__init__.py @@ -23,9 +23,9 @@ img3 = nib.load('spm_file.img') data = img1.get_data() - affine = img1.get_affine() + affine = img1.affine - print img1 + print(img1) nib.save(img1, 'my_file_copy.nii.gz') @@ -63,7 +63,6 @@ apply_orientation, aff2axcodes) from .imageclasses import class_map, ext_map, all_image_classes from . import trackvis -from .streamlines import Tractogram from . import mriutils from . import viewers From d082756b70ee0864f9f9a0150f77044772ace685 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 29 Nov 2015 02:57:15 -0500 Subject: [PATCH 045/135] Removed streamlines benchmark for now. --- nibabel/benchmarks/bench_streamlines.py | 136 ------------------------ 1 file changed, 136 deletions(-) delete mode 100644 nibabel/benchmarks/bench_streamlines.py diff --git a/nibabel/benchmarks/bench_streamlines.py b/nibabel/benchmarks/bench_streamlines.py deleted file mode 100644 index 3d0da178c1..0000000000 --- a/nibabel/benchmarks/bench_streamlines.py +++ /dev/null @@ -1,136 +0,0 @@ -""" Benchmarks for load and save of streamlines - -Run benchmarks with:: - - import nibabel as nib - nib.bench() - -If you have doctests enabled by default in nose (with a noserc file or -environment variable), and you have a numpy version <= 1.6.1, this will also run -the doctests, let's hope they pass. 
- -Run this benchmark with: - - nosetests -s --match '(?:^|[\\b_\\.//-])[Bb]ench' /path/to/bench_streamlines.py -""" -from __future__ import division, print_function - -import os -import numpy as np - -from nibabel.externals.six import BytesIO -from nibabel.externals.six.moves import zip - -from nibabel.testing import assert_arrays_equal - -from numpy.testing import assert_array_equal -from nibabel.streamlines.base_format import Streamlines -from nibabel.streamlines import TrkFile - -import nibabel as nib -import nibabel.trackvis as tv - -from numpy.testing import measure - - -def bench_load_trk(): - NB_STREAMLINES = 1000 - NB_POINTS = 1000 - points = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)] - repeat = 20 - - trk_file = BytesIO() - #trk = list(zip(points, [None]*NB_STREAMLINES, [None]*NB_STREAMLINES)) - #tv.write(trk_file, trk) - streamlines = Streamlines(points) - TrkFile.save(streamlines, trk_file) - - # from pycallgraph import PyCallGraph - # from pycallgraph.output import GraphvizOutput - - # with PyCallGraph(output=GraphvizOutput()): - # #nib.streamlines.load(trk_file, ref=None, lazy_load=False) - - mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, ref=None, lazy_load=False)', repeat) - print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) - - mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file, points_space="voxel")', repeat) - print("Old: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old)) - print("Speedup of %2f" % (mtime_old/mtime_new)) - - # Points and scalars - scalars = [np.random.rand(NB_POINTS, 10).astype('float32') for i in range(NB_STREAMLINES)] - - trk_file = BytesIO() - #trk = list(zip(points, scalars, [None]*NB_STREAMLINES)) - #tv.write(trk_file, trk) - streamlines = Streamlines(points, scalars) - TrkFile.save(streamlines, trk_file) - - mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, ref=None, lazy_load=False)', repeat) - print("New: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_new)) - - mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file, points_space="voxel")', repeat) - print("Old: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_old)) - print("Speedup of %2f" % (mtime_old/mtime_new)) - - -def bench_save_trk(): - NB_STREAMLINES = 100 - NB_POINTS = 1000 - points = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)] - repeat = 10 - - # Only points - streamlines = Streamlines(points) - trk_file_new = BytesIO() - - mtime_new = measure('trk_file_new.seek(0, os.SEEK_SET); TrkFile.save(streamlines, trk_file_new)', repeat) - print("\nNew: Saved %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) - - trk_file_old = BytesIO() - trk = list(zip(points, [None]*NB_STREAMLINES, [None]*NB_STREAMLINES)) - mtime_old = measure('trk_file_old.seek(0, os.SEEK_SET); tv.write(trk_file_old, trk)', repeat) - print("Old: Saved %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old)) - print("Speedup of %2f" % (mtime_old/mtime_new)) - - trk_file_new.seek(0, os.SEEK_SET) - trk_file_old.seek(0, os.SEEK_SET) - streams, hdr = tv.read(trk_file_old) - - for pts, A in zip(points, streams): - assert_array_equal(pts, A[0]) - - trk = nib.streamlines.load(trk_file_new, ref=None, lazy_load=False) - assert_arrays_equal(points, trk.points) - - # Points and scalars - scalars = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)] - 
streamlines = Streamlines(points, scalars=scalars) - trk_file_new = BytesIO() - - mtime_new = measure('trk_file_new.seek(0, os.SEEK_SET); TrkFile.save(streamlines, trk_file_new)', repeat) - print("New: Saved %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_new)) - - trk_file_old = BytesIO() - trk = list(zip(points, scalars, [None]*NB_STREAMLINES)) - mtime_old = measure('trk_file_old.seek(0, os.SEEK_SET); tv.write(trk_file_old, trk)', repeat) - print("Old: Saved %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_old)) - print("Speedup of %2f" % (mtime_old/mtime_new)) - - trk_file_new.seek(0, os.SEEK_SET) - trk_file_old.seek(0, os.SEEK_SET) - streams, hdr = tv.read(trk_file_old) - - for pts, scal, A in zip(points, scalars, streams): - assert_array_equal(pts, A[0]) - assert_array_equal(scal, A[1]) - - trk = nib.streamlines.load(trk_file_new, ref=None, lazy_load=False) - - assert_arrays_equal(points, trk.points) - assert_arrays_equal(scalars, trk.scalars) - - -if __name__ == '__main__': - bench_save_trk() From 4bc8f04381f9a3c546adcde77b6e365337337226 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 29 Nov 2015 14:55:00 -0500 Subject: [PATCH 046/135] Python2.6 compatibility fix. Thanks @effigies --- nibabel/streamlines/trk.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 251d7681c8..ede5f81124 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -647,8 +647,8 @@ def load(cls, fileobj, lazy_load=False): if lazy_load: def _read(): for pts, scals, props in trk_reader: - data_for_points = {k: scals[:, v] for k, v in data_per_point_slice.items()} - data_for_streamline = {k: props[v] for k, v in data_per_streamline_slice.items()} + data_for_points = dict((k, scals[:, v]) for k, v in data_per_point_slice.items()) + data_for_streamline = dict((k, props[v]) for k, v in data_per_streamline_slice.items()) yield TractogramItem(pts, data_for_streamline, data_for_points) tractogram = LazyTractogram.create_from(_read) From c60b34824fe0a3bc24a8240d9908017f011d9d78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 29 Nov 2015 16:07:35 -0500 Subject: [PATCH 047/135] Added more unit tests to increase coverage. --- nibabel/streamlines/base_format.py | 12 --- nibabel/streamlines/compact_list.py | 16 +-- .../streamlines/tests/test_compact_list.py | 26 +++++ nibabel/streamlines/tests/test_tractogram.py | 98 ++++++++++++++++++- nibabel/streamlines/tests/test_trk.py | 3 +- nibabel/streamlines/tractogram.py | 10 +- nibabel/streamlines/tractogram_file.py | 12 +++ nibabel/streamlines/trk.py | 2 +- 8 files changed, 153 insertions(+), 26 deletions(-) delete mode 100644 nibabel/streamlines/base_format.py diff --git a/nibabel/streamlines/base_format.py b/nibabel/streamlines/base_format.py deleted file mode 100644 index e694e7e0c4..0000000000 --- a/nibabel/streamlines/base_format.py +++ /dev/null @@ -1,12 +0,0 @@ - - -class HeaderWarning(Warning): - pass - - -class HeaderError(Exception): - pass - - -class DataError(Exception): - pass diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index 785e363490..ffe2e9be9d 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -5,6 +5,9 @@ class CompactList(object): """ Class for compacting list of ndarrays with matching shape except for the first dimension. 
""" + + BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. + def __init__(self, iterable=None): """ Parameters @@ -31,20 +34,19 @@ def __init__(self, iterable=None): elif iterable is not None: # Initialize the `CompactList` object from iterable's item. - BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. - offset = 0 for i, e in enumerate(iterable): e = np.asarray(e) if i == 0: - self._data = np.empty((BUFFER_SIZE,) + e.shape[1:], - dtype=e.dtype) + new_shape = (CompactList.BUFFER_SIZE,) + e.shape[1:] + self._data = np.empty(new_shape, dtype=e.dtype) end = offset + len(e) if end >= len(self._data): - # Resize is needed (at least `len(e)` items will be added). - self._data.resize((len(self._data) + len(e)+BUFFER_SIZE,) - + self.shape) + # Resize needed, adding `len(e)` new items plus some buffer. + nb_points = len(self._data) + nb_points += len(e) + CompactList.BUFFER_SIZE + self._data.resize((nb_points,) + self.shape) self._offsets.append(offset) self._lengths.append(len(e)) diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py index 34b1ceb65d..7a1c25000f 100644 --- a/nibabel/streamlines/tests/test_compact_list.py +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -51,6 +51,20 @@ def test_creating_compactlist_from_list(self): assert_true(clist._data is None) assert_true(clist.shape is None) + # Force CompactList constructor to use buffering. + old_buffer_size = CompactList.BUFFER_SIZE + CompactList.BUFFER_SIZE = 1 + clist = CompactList(data) + assert_equal(len(clist), len(data)) + assert_equal(len(clist._offsets), len(data)) + assert_equal(len(clist._lengths), len(data)) + assert_equal(clist._data.shape[0], sum(lengths)) + assert_equal(clist._data.shape[1], 3) + assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) + assert_equal(clist._lengths, lengths) + assert_equal(clist.shape, data[0].shape[1:]) + CompactList.BUFFER_SIZE = old_buffer_size + def test_creating_compactlist_from_generator(self): rng = np.random.RandomState(42) data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] @@ -95,6 +109,11 @@ def test_compactlist_iter(self): for e, d in zip(self.clist, self.data): assert_array_equal(e, d) + # Try iterate through a corrupted CompactList object. + clist = self.clist.copy() + clist._lengths = clist._lengths[::2] + assert_raises(ValueError, list, clist) + def test_compactlist_copy(self): clist = self.clist.copy() assert_array_equal(clist._data, self.clist._data) @@ -219,6 +238,13 @@ def test_compactlist_getitem(self): assert_array_equal(clist_view[1], self.clist[2]) assert_array_equal(clist_view[2], self.clist[4]) + # Test invalid indexing + assert_raises(TypeError, self.clist.__getitem__, 'abc') + + def test_compactlist_repr(self): + # Test that calling repr on a CompactList object is not falling. 
+        repr(self.clist)
+

 def test_save_and_load_compact_list():

diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py
index 64b0e150ab..23e4440515 100644
--- a/nibabel/streamlines/tests/test_tractogram.py
+++ b/nibabel/streamlines/tests/test_tractogram.py
@@ -127,7 +127,7 @@ def test_tractogram_creation(self):
         assert_raises(ValueError, Tractogram, self.streamlines,
                       data_per_point=data_per_point)

-    def test_tractogram_getter(self):
+    def test_tractogram_getitem(self):
         # Tractogram with only streamlines
         tractogram = Tractogram(streamlines=self.streamlines)

@@ -200,6 +200,43 @@ def test_tractogram_add_new_data(self):
         assert_arrays_equal(r_tractogram.data_per_point['colors'],
                             self.colors[::-1])

+    def test_tractogram_copy(self):
+        # Create a tractogram with streamlines and other data.
+        tractogram1 = Tractogram(
+            self.streamlines,
+            data_per_streamline={'mean_curvature': self.mean_curvature,
+                                 'mean_color': self.mean_color},
+            data_per_point={'colors': self.colors})
+
+        # Create a copy of the tractogram.
+        tractogram2 = tractogram1.copy()
+
+        # Check that we copied the data and did not simply create new
+        # references.
+        assert_true(tractogram1 is not tractogram2)
+        assert_true(tractogram1.streamlines is not tractogram2.streamlines)
+        assert_true(tractogram1.data_per_streamline
+                    is not tractogram2.data_per_streamline)
+        assert_true(tractogram1.data_per_streamline['mean_curvature']
+                    is not tractogram2.data_per_streamline['mean_curvature'])
+        assert_true(tractogram1.data_per_streamline['mean_color']
+                    is not tractogram2.data_per_streamline['mean_color'])
+        assert_true(tractogram1.data_per_point
+                    is not tractogram2.data_per_point)
+        assert_true(tractogram1.data_per_point['colors']
+                    is not tractogram2.data_per_point['colors'])
+
+        # Check that the data are equivalent.
+        assert_true(isiterable(tractogram2))
+        assert_equal(len(tractogram1), len(tractogram2))
+        assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines)
+        assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines)
+        assert_arrays_equal(tractogram1.data_per_streamline['mean_curvature'],
+                            tractogram2.data_per_streamline['mean_curvature'])
+        assert_arrays_equal(tractogram1.data_per_streamline['mean_color'],
+                            tractogram2.data_per_streamline['mean_color'])
+        assert_arrays_equal(tractogram1.data_per_point['colors'],
+                            tractogram2.data_per_point['colors'])
+

 class TestLazyTractogram(unittest.TestCase):

@@ -303,7 +340,10 @@ def _data_gen():
         assert_arrays_equal(tractogram.data_per_point['colors'],
                             self.colors)

-    def test_lazy_tractogram_indexing(self):
+        # Creating a LazyTractogram from something that is not a coroutine should raise an error.
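
`LazyTractogram.create_from` expects a zero-argument callable that returns a fresh generator each time it is called, so the tractogram can be traversed more than once; handing it an already-started generator is rejected, as the assertion just below verifies. A minimal sketch of that contract (the data here are illustrative):

    import numpy as np
    from nibabel.streamlines.tractogram import LazyTractogram, TractogramItem

    points = [np.zeros((3, 3), dtype='f4'), np.ones((5, 3), dtype='f4')]

    def make_items():
        # Returns a brand new generator on every call.
        for pts in points:
            yield TractogramItem(pts, {}, {})

    lazy = LazyTractogram.create_from(make_items)  # OK: a callable
    LazyTractogram.create_from(make_items())       # raises TypeError
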
+ assert_raises(TypeError, LazyTractogram.create_from, _data_gen()) + + def test_lazy_tractogram_getitem(self): streamlines = lambda: (x for x in self.streamlines) data_per_point = {"colors": self.colors_func} data_per_streamline = {'mean_curv': self.mean_curvature_func, @@ -382,3 +422,57 @@ def test_lazy_tractogram_apply_affine(self): assert_equal(len(tractogram), len(self.streamlines)) for s1, s2 in zip(tractogram.streamlines, self.streamlines): assert_array_almost_equal(s1, s2*scaling) + + def test_lazy_tractogram_copy(self): + # Create tractogram with streamlines and other data + streamlines = lambda: (x for x in self.streamlines) + data_per_point = {"colors": self.colors_func} + data_per_streamline = {'mean_curv': self.mean_curvature_func, + 'mean_color': self.mean_color_func} + + tractogram1 = LazyTractogram(streamlines, + data_per_streamline=data_per_streamline, + data_per_point=data_per_point) + assert_true(isiterable(tractogram1)) # Implicitly set _nb_streamlines. + + # Create a copy of the tractogram. + tractogram2 = tractogram1.copy() + + # Check we copied the data and not simply created new references. + assert_true(tractogram1 is not tractogram2) + + # When copying LazyTractogram, coroutines generating streamlines should + # be the same. + assert_true(tractogram1._streamlines is tractogram2._streamlines) + + # Copying LazyTractogram, creates new internal LazyDict objects, + # but coroutines contained in it should be the same. + assert_true(tractogram1._data_per_streamline + is not tractogram2._data_per_streamline) + assert_true(tractogram1.data_per_streamline.store['mean_curv'] + is tractogram2.data_per_streamline.store['mean_curv']) + assert_true(tractogram1.data_per_streamline.store['mean_color'] + is tractogram2.data_per_streamline.store['mean_color']) + assert_true(tractogram1._data_per_point + is not tractogram2._data_per_point) + assert_true(tractogram1.data_per_point.store['colors'] + is tractogram2.data_per_point.store['colors']) + + # The affine should be a copy. + assert_true(tractogram1._affine_to_apply + is not tractogram2._affine_to_apply) + assert_array_equal(tractogram1._affine_to_apply, + tractogram2._affine_to_apply) + + # Check the data are the equivalent. + assert_equal(tractogram1._nb_streamlines, tractogram2._nb_streamlines) + assert_true(isiterable(tractogram2)) + assert_equal(len(tractogram1), len(tractogram2)) + assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) + assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) + assert_arrays_equal(tractogram1.data_per_streamline['mean_curv'], + tractogram2.data_per_streamline['mean_curv']) + assert_arrays_equal(tractogram1.data_per_streamline['mean_color'], + tractogram2.data_per_streamline['mean_color']) + assert_arrays_equal(tractogram1.data_per_point['colors'], + tractogram2.data_per_point['colors']) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 80853ea80a..ba3f3d3291 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -9,9 +9,8 @@ from nose.tools import assert_equal, assert_raises, assert_true from .test_tractogram import assert_tractogram_equal -from .. import base_format from ..tractogram import Tractogram, LazyTractogram -from ..base_format import DataError, HeaderError, HeaderWarning#, UsageWarning +from ..tractogram_file import DataError, HeaderError, HeaderWarning from .. 
import trk as trk_module from ..trk import TrkFile, header_2_dtype diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 2ae433fd4b..29fb975eee 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -202,7 +202,12 @@ class LazyDict(collections.MutableMapping): def __init__(self, *args, **kwargs): self.store = dict() - self.update(dict(*args, **kwargs)) # Use update to set keys. + + # Use update to set keys. + if len(args) == 1 and isinstance(args[0], LazyTractogram.LazyDict): + self.update(dict(args[0].store.items())) + else: + self.update(dict(*args, **kwargs)) def __getitem__(self, key): return self.store[key]() @@ -377,8 +382,9 @@ def copy(self): tractogram = LazyTractogram(self._streamlines, self._data_per_streamline, self._data_per_point) - tractogram.nb_streamlines = self.nb_streamlines + tractogram._nb_streamlines = self._nb_streamlines tractogram._data = self._data + tractogram._affine_to_apply = self._affine_to_apply.copy() return tractogram def apply_affine(self, affine): diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index feea0e9822..8bb1aa41ea 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -3,6 +3,18 @@ from .header import TractogramHeader +class HeaderWarning(Warning): + pass + + +class HeaderError(Exception): + pass + + +class DataError(Exception): + pass + + class abstractclassmethod(classmethod): __isabstractmethod__ = True diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index ede5f81124..9b043ff821 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -18,7 +18,7 @@ from .compact_list import CompactList from .tractogram_file import TractogramFile -from .base_format import DataError, HeaderError, HeaderWarning +from .tractogram_file import DataError, HeaderError, HeaderWarning from .tractogram import TractogramItem, Tractogram, LazyTractogram from .header import Field From d4b05f969afe4fef096362a2179c72489a513ac3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 1 Dec 2015 14:39:50 -0500 Subject: [PATCH 048/135] Changed function name 'isiterable' to 'check_iteration' --- nibabel/streamlines/tests/test_streamlines.py | 2 +- nibabel/streamlines/tests/test_tractogram.py | 28 +++++++++---------- nibabel/streamlines/tests/test_trk.py | 2 +- nibabel/testing/__init__.py | 10 ------- 4 files changed, 16 insertions(+), 26 deletions(-) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 645b3ac4cb..f76e8ab581 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -9,7 +9,7 @@ from nibabel.externals.six import BytesIO from nibabel.testing import clear_and_catch_warnings -from nibabel.testing import assert_arrays_equal, isiterable +from nibabel.testing import assert_arrays_equal, check_iteration from nose.tools import assert_equal, assert_raises, assert_true, assert_false from .test_tractogram import assert_tractogram_equal diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 23e4440515..953765bc9b 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -2,7 +2,7 @@ import numpy as np import warnings -from nibabel.testing import assert_arrays_equal, isiterable +from nibabel.testing import assert_arrays_equal, 
check_iteration from nibabel.testing import suppress_warnings, clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true from numpy.testing import assert_array_equal, assert_array_almost_equal @@ -14,7 +14,7 @@ def assert_tractogram_equal(t1, t2): - assert_true(isiterable(t1)) + assert_true(check_iteration(t1)) assert_equal(len(t1), len(t2)) assert_arrays_equal(t1.streamlines, t2.streamlines) @@ -81,7 +81,7 @@ def test_tractogram_creation(self): assert_arrays_equal(tractogram.streamlines, []) assert_equal(tractogram.data_per_streamline, {}) assert_equal(tractogram.data_per_point, {}) - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) # Create a tractogram with only streamlines tractogram = Tractogram(streamlines=self.streamlines) @@ -89,7 +89,7 @@ def test_tractogram_creation(self): assert_arrays_equal(tractogram.streamlines, self.streamlines) assert_equal(tractogram.data_per_streamline, {}) assert_equal(tractogram.data_per_point, {}) - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) # Create a tractogram with streamlines and other data. tractogram = Tractogram( @@ -107,7 +107,7 @@ def test_tractogram_creation(self): assert_arrays_equal(tractogram.data_per_point['colors'], self.colors) - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) # Inconsistent number of scalars between streamlines wrong_data = [[(1, 0, 0)]*1, @@ -226,7 +226,7 @@ def test_tractogram_copy(self): is not tractogram2.data_per_point['colors']) # Check the data are the equivalent. - assert_true(isiterable(tractogram2)) + assert_true(check_iteration(tractogram2)) assert_equal(len(tractogram1), len(tractogram2)) assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) @@ -280,7 +280,7 @@ def test_lazy_tractogram_creation(self): # Empty `LazyTractogram` tractogram = LazyTractogram() - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) assert_equal(len(tractogram), 0) assert_arrays_equal(tractogram.streamlines, []) assert_equal(tractogram.data_per_point, {}) @@ -296,7 +296,7 @@ def test_lazy_tractogram_creation(self): data_per_streamline=data_per_streamline, data_per_point=data_per_point) - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) assert_equal(len(tractogram), self.nb_streamlines) # Coroutines get re-called and creates new iterators. @@ -314,7 +314,7 @@ def test_lazy_tractogram_create_from(self): _empty_data_gen = lambda: iter([]) tractogram = LazyTractogram.create_from(_empty_data_gen) - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) assert_equal(len(tractogram), 0) assert_arrays_equal(tractogram.streamlines, []) assert_equal(tractogram.data_per_point, {}) @@ -330,7 +330,7 @@ def _data_gen(): yield TractogramItem(d[0], data_for_streamline, data_for_points) tractogram = LazyTractogram.create_from(_data_gen) - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) assert_equal(len(tractogram), self.nb_streamlines) assert_arrays_equal(tractogram.streamlines, self.streamlines) assert_arrays_equal(tractogram.data_per_streamline['mean_curv'], @@ -397,7 +397,7 @@ def test_lazy_tractogram_len(self): data_per_point=data_per_point) assert_true(tractogram._nb_streamlines is None) - isiterable(tractogram) # Force to iterate through all streamlines. + check_iteration(tractogram) # Force iteration through tractogram. 
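
`check_iteration` (the new name of `isiterable`) presumably just tries to exhaust its argument and reports whether that succeeded; the useful side effect relied on here is that one complete pass over a `LazyTractogram` lets it cache its number of streamlines. A sketch of the assumed helper (not necessarily the actual implementation in `nibabel.testing`):

    def check_iteration(iterable):
        # Consume the iterable; report whether iteration completed.
        try:
            for _ in iterable:
                pass
        except Exception:
            return False
        return True
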
assert_equal(tractogram._nb_streamlines, len(self.streamlines)) # This should *not* produce a warning. assert_equal(len(tractogram), len(self.streamlines)) @@ -418,7 +418,7 @@ def test_lazy_tractogram_apply_affine(self): data_per_point=data_per_point) tractogram.apply_affine(affine) - assert_true(isiterable(tractogram)) + assert_true(check_iteration(tractogram)) assert_equal(len(tractogram), len(self.streamlines)) for s1, s2 in zip(tractogram.streamlines, self.streamlines): assert_array_almost_equal(s1, s2*scaling) @@ -433,7 +433,7 @@ def test_lazy_tractogram_copy(self): tractogram1 = LazyTractogram(streamlines, data_per_streamline=data_per_streamline, data_per_point=data_per_point) - assert_true(isiterable(tractogram1)) # Implicitly set _nb_streamlines. + assert_true(check_iteration(tractogram1)) # Implicitly set _nb_streamlines. # Create a copy of the tractogram. tractogram2 = tractogram1.copy() @@ -466,7 +466,7 @@ def test_lazy_tractogram_copy(self): # Check the data are the equivalent. assert_equal(tractogram1._nb_streamlines, tractogram2._nb_streamlines) - assert_true(isiterable(tractogram2)) + assert_true(check_iteration(tractogram2)) assert_equal(len(tractogram1), len(tractogram2)) assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index ba3f3d3291..f689e53ef8 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -5,7 +5,7 @@ from nibabel.externals.six import BytesIO from nibabel.testing import suppress_warnings, clear_and_catch_warnings -from nibabel.testing import assert_arrays_equal, isiterable +from nibabel.testing import assert_arrays_equal, check_iteration from nose.tools import assert_equal, assert_raises, assert_true from .test_tractogram import assert_tractogram_equal diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py index edf6002394..e5332014de 100644 --- a/nibabel/testing/__init__.py +++ b/nibabel/testing/__init__.py @@ -179,17 +179,7 @@ class suppress_warnings(error_warnings): class catch_warn_reset(clear_and_catch_warnings): - def __init__(self, *args, **kwargs): warnings.warn('catch_warn_reset is deprecated and will be removed in ' 'nibabel v3.0; use nibabel.testing.clear_and_catch_warnings.', FutureWarning) - - -EXTRA_SET = os.environ.get('NIPY_EXTRA_TESTS', '').split(',') - - -def runif_extra_has(test_str): - """Decorator checks to see if NIPY_EXTRA_TESTS env var contains test_str""" - return skipif(test_str not in EXTRA_SET, - "Skip {0} tests.".format(test_str)) From 476a93ad9bdc2f4f84782e9bb255113064c66983 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 1 Dec 2015 14:55:53 -0500 Subject: [PATCH 049/135] Support upper case file extension --- nibabel/streamlines/__init__.py | 7 +-- nibabel/streamlines/tests/test_streamlines.py | 43 +++++++++++-------- 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index abf9ca27f3..6cb4bdd87a 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -1,3 +1,6 @@ +import os +from ..externals.six import string_types + from .header import TractogramHeader from .compact_list import CompactList from .tractogram import Tractogram, LazyTractogram @@ -54,11 +57,9 @@ def detect_format(fileobj): except IOError: pass - import os - from ..externals.six import 
string_types if isinstance(fileobj, string_types): _, ext = os.path.splitext(fileobj) - return FORMATS.get(ext, None) + return FORMATS.get(ext.lower()) return None diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index f76e8ab581..c21c688989 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -28,21 +28,21 @@ def test_is_supported(): assert_false(nib.streamlines.is_supported("")) # Valid file without extension - for tractogram_file in nib.streamlines.FORMATS.values(): + for tfile_cls in nib.streamlines.FORMATS.values(): f = BytesIO() - f.write(tractogram_file.get_magic_number()) + f.write(tfile_cls.get_magic_number()) f.seek(0, os.SEEK_SET) assert_true(nib.streamlines.is_supported(f)) # Wrong extension but right magic number - for tractogram_file in nib.streamlines.FORMATS.values(): + for tfile_cls in nib.streamlines.FORMATS.values(): with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f: - f.write(tractogram_file.get_magic_number()) + f.write(tfile_cls.get_magic_number()) f.seek(0, os.SEEK_SET) assert_true(nib.streamlines.is_supported(f)) # Good extension but wrong magic number - for ext, tractogram_file in nib.streamlines.FORMATS.items(): + for ext, tfile_cls in nib.streamlines.FORMATS.items(): with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f: f.write(b"pass") f.seek(0, os.SEEK_SET) @@ -53,7 +53,7 @@ def test_is_supported(): assert_false(nib.streamlines.is_supported(f)) # Good extension, string only - for ext, tractogram_file in nib.streamlines.FORMATS.items(): + for ext, tfile_cls in nib.streamlines.FORMATS.items(): f = "my_tractogram" + ext assert_true(nib.streamlines.is_supported(f)) @@ -61,38 +61,43 @@ def test_is_supported(): def test_detect_format(): # Emtpy file/string f = BytesIO() - assert_equal(nib.streamlines.detect_format(f), None) - assert_equal(nib.streamlines.detect_format(""), None) + assert_true(nib.streamlines.detect_format(f) is None) + assert_true(nib.streamlines.detect_format("") is None) # Valid file without extension - for tractogram_file in nib.streamlines.FORMATS.values(): + for tfile_cls in nib.streamlines.FORMATS.values(): f = BytesIO() - f.write(tractogram_file.get_magic_number()) + f.write(tfile_cls.get_magic_number()) f.seek(0, os.SEEK_SET) - assert_equal(nib.streamlines.detect_format(f), tractogram_file) + assert_true(nib.streamlines.detect_format(f) is tfile_cls) # Wrong extension but right magic number - for tractogram_file in nib.streamlines.FORMATS.values(): + for tfile_cls in nib.streamlines.FORMATS.values(): with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f: - f.write(tractogram_file.get_magic_number()) + f.write(tfile_cls.get_magic_number()) f.seek(0, os.SEEK_SET) - assert_equal(nib.streamlines.detect_format(f), tractogram_file) + assert_true(nib.streamlines.detect_format(f) is tfile_cls) # Good extension but wrong magic number - for ext, tractogram_file in nib.streamlines.FORMATS.items(): + for ext, tfile_cls in nib.streamlines.FORMATS.items(): with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f: f.write(b"pass") f.seek(0, os.SEEK_SET) - assert_equal(nib.streamlines.detect_format(f), None) + assert_true(nib.streamlines.detect_format(f) is None) # Wrong extension, string only f = "my_tractogram.asd" - assert_equal(nib.streamlines.detect_format(f), None) + assert_true(nib.streamlines.detect_format(f) is None) # Good extension, string only - for ext, tractogram_file in nib.streamlines.FORMATS.items(): + 
for ext, tfile_cls in nib.streamlines.FORMATS.items(): f = "my_tractogram" + ext - assert_equal(nib.streamlines.detect_format(f), tractogram_file) + assert_equal(nib.streamlines.detect_format(f), tfile_cls) + + # Extension should not be case-sensitive. + for ext, tfile_cls in nib.streamlines.FORMATS.items(): + f = "my_tractogram" + ext.upper() + assert_true(nib.streamlines.detect_format(f) is tfile_cls) class TestLoadSave(unittest.TestCase): From cac1f17614f275c5eab86959b522ef302c0d1976 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 1 Dec 2015 15:31:00 -0500 Subject: [PATCH 050/135] Fixed typo --- nibabel/streamlines/compact_list.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index ffe2e9be9d..889c48faea 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -148,7 +148,7 @@ def __getitem__(self, idx): Returns ------- ndarray object(s) - When `idx` is a int, returns a single ndarray. + When `idx` is an int, returns a single ndarray. When `idx` is either a slice or a list, returns a list of ndarrays. """ if isinstance(idx, int) or isinstance(idx, np.integer): @@ -178,7 +178,7 @@ def __getitem__(self, idx): for i, take_it in enumerate(idx) if take_it] return clist - raise TypeError("Index must be a int or a slice! Not " + str(type(idx))) + raise TypeError("Index must be an int or a slice! Not " + str(type(idx))) def __iter__(self): if len(self._lengths) != len(self._offsets): From b6591e4e9459d08eb79936f122cbbecb3ce5f768 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 1 Dec 2015 15:52:54 -0500 Subject: [PATCH 051/135] Use isinstance instead of type() whenever possible --- nibabel/streamlines/compact_list.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index 889c48faea..35ca7c98ae 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -155,14 +155,14 @@ def __getitem__(self, idx): start = self._offsets[idx] return self._data[start:start+self._lengths[idx]] - elif type(idx) is slice: + elif isinstance(idx, slice): clist = CompactList() clist._data = self._data clist._offsets = self._offsets[idx] clist._lengths = self._lengths[idx] return clist - elif type(idx) is list: + elif isinstance(idx, list): clist = CompactList() clist._data = self._data clist._offsets = [self._offsets[i] for i in idx] @@ -178,7 +178,8 @@ def __getitem__(self, idx): for i, take_it in enumerate(idx) if take_it] return clist - raise TypeError("Index must be an int or a slice! Not " + str(type(idx))) + raise TypeError("Index must be either an int, a slice, a list of int" + " or a ndarray of bool! Not " + str(type(idx))) def __iter__(self): if len(self._lengths) != len(self._offsets): From f1312115d1e75fe34df07299561a2cf213d32d25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 1 Dec 2015 20:16:09 -0500 Subject: [PATCH 052/135] Added the module streamlines to nibabel --- nibabel/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nibabel/__init__.py b/nibabel/__init__.py index 4d8791d7d9..91ccace44b 100644 --- a/nibabel/__init__.py +++ b/nibabel/__init__.py @@ -64,6 +64,7 @@ from .imageclasses import class_map, ext_map, all_image_classes from . import trackvis from . import mriutils +from . import streamlines from . 
import viewers # be friendly on systems with ancient numpy -- no tests, but at least From 49a046288a63ba950ec13e8b04000a757c9222f9 Mon Sep 17 00:00:00 2001 From: Marc-Alexandre Cote Date: Wed, 2 Dec 2015 17:58:29 -0500 Subject: [PATCH 053/135] BF: CompactList.extend with a sliced CompactList was not doing the right thing --- nibabel/streamlines/compact_list.py | 19 ++++++--- .../streamlines/tests/test_compact_list.py | 41 +++++++++++++------ 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index 35ca7c98ae..8fb761f312 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -107,16 +107,23 @@ def extend(self, elements): elem = np.asarray(elements[0]) self._data = np.zeros((0, elem.shape[1]), dtype=elem.dtype) + next_offset = self._data.shape[0] + if isinstance(elements, CompactList): - self._data = np.concatenate([self._data, elements._data], axis=0) - lengths = elements._lengths + self._data.resize((self._data.shape[0]+sum(elements._lengths), + self._data.shape[1])) + + for offset, length in zip(elements._offsets, elements._lengths): + self._offsets.append(next_offset) + self._lengths.append(length) + self._data[next_offset:next_offset+length] = elements._data[offset:offset+length] + next_offset += length + else: self._data = np.concatenate([self._data] + list(elements), axis=0) lengths = list(map(len, elements)) - - idx = self._offsets[-1] + self._lengths[-1] if len(self) > 0 else 0 - self._lengths.extend(lengths) - self._offsets.extend(np.cumsum([idx] + lengths).tolist()[:-1]) + self._lengths.extend(lengths) + self._offsets.extend(np.cumsum([next_offset] + lengths).tolist()[:-1]) def copy(self): """ Creates a copy of this ``CompactList`` object. """ diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py index 7a1c25000f..188102ce85 100644 --- a/nibabel/streamlines/tests/test_compact_list.py +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -4,6 +4,7 @@ import numpy as np from nose.tools import assert_equal, assert_raises, assert_true +from nibabel.testing import assert_arrays_equal from numpy.testing import assert_array_equal from nibabel.externals.six.moves import zip, zip_longest @@ -170,7 +171,7 @@ def test_compactlist_extend(self): rng = np.random.RandomState(1234) shape = self.clist.shape - new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(5)] + new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(10)] lengths = list(map(len, new_data)) clist.extend(new_data) assert_equal(len(clist), len(self.clist)+len(new_data)) @@ -183,22 +184,38 @@ def test_compactlist_extend(self): # Extend with another `CompactList` object. clist = self.clist.copy() - new_data = CompactList(new_data) - clist.extend(new_data) - assert_equal(len(clist), len(self.clist)+len(new_data)) - assert_array_equal(clist._offsets[-len(new_data):], + new_clist = CompactList(new_data) + clist.extend(new_clist) + assert_equal(len(clist), len(self.clist)+len(new_clist)) + assert_array_equal(clist._offsets[-len(new_clist):], len(self.clist._data) + np.cumsum([0] + lengths[:-1])) - assert_equal(clist._lengths[-len(new_data):], lengths) - assert_array_equal(clist._data[-sum(lengths):], new_data._data) + assert_equal(clist._lengths[-len(new_clist):], lengths) + assert_array_equal(clist._data[-sum(lengths):], new_clist._data) + + # Extend with another `CompactList` object that is a view (e.g. been sliced). 
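
This fix matters because slicing a `CompactList` returns a view that keeps the full `_data` array of its parent, with only `_offsets` and `_lengths` trimmed; a correct `extend` must therefore copy segment by segment rather than concatenate `_data` wholesale. A small sketch of the behaviour the new test pins down (shapes are arbitrary, and it assumes the constructor trims its internal buffer, as the tests assert):

    import numpy as np
    from nibabel.streamlines.compact_list import CompactList

    clist = CompactList([np.ones((3, 3)), np.ones((5, 3)), np.ones((2, 3))])
    view = clist[::2]  # shares clist._data; keeps items 0 and 2 only

    other = CompactList([np.zeros((4, 3))])
    other.extend(view)

    # Only the 3 + 2 rows referenced by the view are appended,
    # not the 10-row backing array of `clist`.
    assert len(other._data) == 4 + 3 + 2
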
+ # Need to make sure we extend only the data we need. + clist = self.clist.copy() + new_clist = CompactList(new_data)[::2] + clist.extend(new_clist) + assert_equal(len(clist), len(self.clist)+len(new_clist)) + assert_equal(len(clist._data), len(self.clist._data)+sum(new_clist._lengths)) + assert_array_equal(clist._offsets[-len(new_clist):], + len(self.clist._data) + np.cumsum([0] + new_clist._lengths[:-1])) + + assert_equal(clist._lengths[-len(new_clist):], lengths[::2]) + assert_array_equal(clist._data[-sum(new_clist._lengths):], new_clist.copy()._data) + assert_arrays_equal(clist[-len(new_clist):], new_clist) # Test extending an empty CompactList clist = CompactList() - clist.extend(new_data) - assert_equal(len(clist), len(new_data)) - assert_array_equal(clist._offsets, new_data._offsets) - assert_array_equal(clist._lengths, new_data._lengths) - assert_array_equal(clist._data, new_data._data) + new_clist = CompactList(new_data) + clist.extend(new_clist) + assert_equal(len(clist), len(new_clist)) + assert_array_equal(clist._offsets, new_clist._offsets) + assert_array_equal(clist._lengths, new_clist._lengths) + assert_array_equal(clist._data, new_clist._data) + def test_compactlist_getitem(self): # Get one item From 2edd67e3b9bf1c08a15eb49e9f064278482a3e06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Thu, 3 Dec 2015 12:17:28 -0500 Subject: [PATCH 054/135] Now use custom dictionnaries for data_per_streamline and data_per_point --- nibabel/streamlines/tests/test_tractogram.py | 40 ++++++++++ nibabel/streamlines/tests/test_trk.py | 37 --------- nibabel/streamlines/tractogram.py | 82 +++++++++++++++++--- 3 files changed, 110 insertions(+), 49 deletions(-) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 953765bc9b..f7d22e3782 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -237,6 +237,46 @@ def test_tractogram_copy(self): assert_arrays_equal(tractogram1.data_per_point['colors'], tractogram2.data_per_point['colors']) + def test_creating_invalid_tractogram(self): + # Not enough data_per_point for all the points of all streamlines. + scalars = [[(1, 0, 0)]*1, + [(0, 1, 0)]*2, + [(0, 0, 1)]*3] # Last streamlines has 5 points. + + assert_raises(ValueError, Tractogram, self.streamlines, + data_per_point={'scalars': scalars}) + + # Not enough data_per_streamline for all streamlines. + properties = [np.array([1.11, 1.22], dtype="f4"), + np.array([3.11, 3.22], dtype="f4")] + + assert_raises(ValueError, Tractogram, self.streamlines, + data_per_streamline={'properties': properties}) + + # Inconsistent dimension for a data_per_point. + scalars = [[(1, 0, 0)]*1, + [(0, 1)]*2, + [(0, 0, 1)]*5] + + assert_raises(ValueError, Tractogram, self.streamlines, + data_per_point={'scalars': scalars}) + + # Inconsistent dimension for a data_per_streamline. + properties = [[1.11, 1.22], + [2.11], + [3.11, 3.22]] + + assert_raises(ValueError, Tractogram, self.streamlines, + data_per_streamline={'properties': properties}) + + # Too many dimension for a data_per_streamline. 
+ properties = [np.array([[1.11], [1.22]], dtype="f4"), + np.array([[2.11], [2.22]], dtype="f4"), + np.array([[3.11], [3.22]], dtype="f4")] + + assert_raises(ValueError, Tractogram, self.streamlines, + data_per_streamline={'properties': properties}) + class TestLazyTractogram(unittest.TestCase): diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index f689e53ef8..1a9b33fa6b 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -208,43 +208,6 @@ def test_write_complex_file(self): assert_equal(trk_file.read(), open(self.complex_trk_filename, 'rb').read()) - def test_write_erroneous_file(self): - # No scalars for every points - scalars = [[(1, 0, 0)], - [(0, 1, 0)], - [(0, 0, 1)]] - - tractogram = Tractogram(self.streamlines, - data_per_point={'scalars': scalars}) - trk = TrkFile(tractogram) - assert_raises(DataError, trk.save, BytesIO()) - - # No scalars for every streamlines - scalars = [[(1, 0, 0)]*1, - [(0, 1, 0)]*2] - - tractogram = Tractogram(self.streamlines, - data_per_point={'scalars': scalars}) - trk = TrkFile(tractogram) - assert_raises(IndexError, trk.save, BytesIO()) - - # Inconsistent number of properties - properties = [np.array([1.11, 1.22], dtype="f4"), - np.array([2.11], dtype="f4"), - np.array([3.11, 3.22], dtype="f4")] - tractogram = Tractogram(self.streamlines, - data_per_streamline={'properties': properties}) - trk = TrkFile(tractogram) - assert_raises(DataError, trk.save, BytesIO()) - - # No properties for every streamlines - properties = [np.array([1.11, 1.22], dtype="f4"), - np.array([2.11, 2.22], dtype="f4")] - tractogram = Tractogram(self.streamlines, - data_per_streamline={'properties': properties}) - trk = TrkFile(tractogram) - assert_raises(IndexError, trk.save, BytesIO()) - def test_load_write_file(self): for filename in [self.empty_trk_filename, self.simple_trk_filename, self.complex_trk_filename]: for lazy_load in [False, True]: diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 29fb975eee..e422c10ca5 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -62,6 +62,73 @@ class Tractogram(object): associated to each point (excluding the three coordinates). """ + class DataDict(collections.MutableMapping): + def __init__(self, tractogram, *args, **kwargs): + self.tractogram = tractogram + self.store = dict() + + # Use update to set the keys. + if len(args) == 1: + if isinstance(args[0], Tractogram.DataDict): + self.update(dict(args[0].store.items())) + elif args[0] is None: + return + else: + self.update(dict(*args, **kwargs)) + else: + self.update(dict(*args, **kwargs)) + + def __getitem__(self, key): + return self.store[key] + + def __delitem__(self, key): + del self.store[key] + + def __iter__(self): + return iter(self.store) + + def __len__(self): + return len(self.store) + + class DataPerStreamlineDict(DataDict): + """ Internal dictionary that makes sure data are 2D ndarray. """ + + def __setitem__(self, key, value): + value = np.asarray(value) + + if value.ndim == 1 and value.dtype != object: + # Reshape without copy + value.shape = ((len(value), 1)) + + if value.ndim != 2: + raise ValueError("data_per_streamline must be a 2D ndarray.") + + # We make sure there is the right amount of values + # (i.e. same as the number of streamlines in the tractogram). 
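
Under the new validating dictionaries, assigning to `data_per_streamline` coerces the value to a 2D array (a 1D sequence becomes a column vector) and checks it against the number of streamlines, raising `ValueError` otherwise, which is what the tests above exercise. A short sketch (values are made up):

    import numpy as np
    from nibabel.streamlines import Tractogram

    streamlines = [np.zeros((n, 3), dtype='f4') for n in (2, 3, 5)]
    t = Tractogram(streamlines)

    t.data_per_streamline['mean_curv'] = [1.0, 2.0, 3.0]
    assert t.data_per_streamline['mean_curv'].shape == (3, 1)

    try:
        t.data_per_streamline['mean_curv'] = [1.0, 2.0]  # one value short
    except ValueError:
        pass  # expected: 2 values for 3 streamlines
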
+ if len(value) != len(self.tractogram): + msg = ("The number of values ({0}) should match the number of" + " streamlines ({1}).") + raise ValueError(msg.format(len(value), len(self.tractogram))) + + self.store[key] = value + + class DataPerPointDict(DataDict): + """ Internal dictionary that makes sure data are `CompactList`. """ + + def __setitem__(self, key, value): + value = CompactList(value) + + # We make sure we have the right amount of values (i.e. same as + # the total number of points of all streamlines in the tractogram). + if len(value._data) != len(self.tractogram.streamlines._data): + msg = ("The number of values ({0}) should match the total" + " number of points of all streamlines ({1}).") + nb_streamlines_points = self.tractogram.streamlines._data + raise ValueError(msg.format(len(value._data), + len(nb_streamlines_points))) + + self.store[key] = value + def __init__(self, streamlines=None, data_per_streamline=None, data_per_point=None): @@ -84,12 +151,8 @@ def data_per_streamline(self): @data_per_streamline.setter def data_per_streamline(self, value): - if value is None: - value = {} - - self._data_per_streamline = {} - for k, v in value.items(): - self._data_per_streamline[k] = np.asarray(v) + self._data_per_streamline = Tractogram.DataPerStreamlineDict(self, + value) @property def data_per_point(self): @@ -97,12 +160,7 @@ def data_per_point(self): @data_per_point.setter def data_per_point(self, value): - if value is None: - value = {} - - self._data_per_point = {} - for k, v in value.items(): - self._data_per_point[k] = CompactList(v) + self._data_per_point = Tractogram.DataPerPointDict(self, value) def __iter__(self): for i in range(len(self.streamlines)): From f99bc32d472bdb31394ffee283d1a137984fd218 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 4 Dec 2015 00:17:05 -0500 Subject: [PATCH 055/135] Removed TractogramHeader class --- nibabel/streamlines/__init__.py | 2 +- nibabel/streamlines/header.py | 122 ----------------------- nibabel/streamlines/tests/test_header.py | 37 ------- nibabel/streamlines/tractogram_file.py | 6 +- 4 files changed, 2 insertions(+), 165 deletions(-) delete mode 100644 nibabel/streamlines/tests/test_header.py diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 6cb4bdd87a..213d64297e 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -1,7 +1,7 @@ import os from ..externals.six import string_types -from .header import TractogramHeader +from .header import Field from .compact_list import CompactList from .tractogram import Tractogram, LazyTractogram diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py index 3fac6952bd..668d95ec78 100644 --- a/nibabel/streamlines/header.py +++ b/nibabel/streamlines/header.py @@ -1,8 +1,3 @@ -import copy -import numpy as np -from nibabel.orientations import aff2axcodes -from nibabel.externals import OrderedDict - class Field: """ Header fields common to multiple streamlines file formats. 
@@ -22,120 +17,3 @@ class Field: VOXEL_TO_RASMM = "voxel_to_rasmm" VOXEL_ORDER = "voxel_order" ENDIAN = "endian" - - -class TractogramHeader(object): - def __init__(self, hdr=None): - self._nb_streamlines = None - self._nb_scalars_per_point = None - self._nb_properties_per_streamline = None - self._to_world_space = np.eye(4) - self.extra = OrderedDict() - - if type(hdr) is dict: - if Field.NB_POINTS in hdr: - self.nb_streamlines = hdr[Field.NB_POINTS] - - if Field.NB_SCALARS_PER_POINT in hdr: - self.nb_scalars_per_point = hdr[Field.NB_SCALARS_PER_POINT] - - if Field.NB_PROPERTIES_PER_STREAMLINE in hdr: - self.nb_properties_per_streamline = hdr[Field.NB_PROPERTIES_PER_STREAMLINE] - - if Field.VOXEL_TO_RASMM in hdr: - self.to_world_space = hdr[Field.VOXEL_TO_RASMM] - - elif type(hdr) is TractogramHeader: - self._nb_streamlines = hdr._nb_streamlines - self._nb_scalars_per_point = hdr._nb_scalars_per_point - self._nb_properties_per_streamline = hdr._nb_properties_per_streamline - self._to_world_space = hdr._to_world_space - self.extra = copy.deepcopy(hdr.extra) - - @property - def to_world_space(self): - return self._to_world_space - - @to_world_space.setter - def to_world_space(self, value): - self._to_world_space = np.asarray(value, dtype=np.float32) - - @property - def voxel_sizes(self): - """ Get voxel sizes from to_world_space. """ - return np.sqrt(np.sum(self.to_world_space[:3, :3]**2, axis=0)) - - @voxel_sizes.setter - def voxel_sizes(self, value): - scaling = np.r_[np.array(value), [1]] - old_scaling = np.r_[np.array(self.voxel_sizes), [1]] - # Remove old scaling and apply new one - self.to_world_space = np.dot(np.diag(scaling/old_scaling), self.to_world_space) - - @property - def voxel_order(self): - """ Get voxel order from to_world_space. """ - return "".join(aff2axcodes(self.to_world_space)) - - @property - def nb_streamlines(self): - return self._nb_streamlines - - @nb_streamlines.setter - def nb_streamlines(self, value): - self._nb_streamlines = int(value) - - @property - def nb_scalars_per_point(self): - return self._nb_scalars_per_point - - @nb_scalars_per_point.setter - def nb_scalars_per_point(self, value): - self._nb_scalars_per_point = int(value) - - @property - def nb_properties_per_streamline(self): - return self._nb_properties_per_streamline - - @nb_properties_per_streamline.setter - def nb_properties_per_streamline(self, value): - self._nb_properties_per_streamline = int(value) - - @property - def extra(self): - return self._extra - - @extra.setter - def extra(self, value): - self._extra = OrderedDict(value) - - def copy(self): - header = TractogramHeader() - header._nb_streamlines = self.nb_streamlines - header.nb_scalars_per_point = self.nb_scalars_per_point - header.nb_properties_per_streamline = self.nb_properties_per_streamline - header.to_world_space = self.to_world_space.copy() - header.extra = copy.deepcopy(self.extra) - return header - - def __eq__(self, other): - return (np.allclose(self.to_world_space, other.to_world_space) and - self.nb_streamlines == other.nb_streamlines and - self.nb_scalars_per_point == other.nb_scalars_per_point and - self.nb_properties_per_streamline == other.nb_properties_per_streamline and - repr(self.extra) == repr(other.extra)) # Not the robust way, but will do! 
- - def __repr__(self): - txt = "Header{\n" - txt += "nb_streamlines: " + repr(self.nb_streamlines) + '\n' - txt += "nb_scalars_per_point: " + repr(self.nb_scalars_per_point) + '\n' - txt += "nb_properties_per_streamline: " + repr(self.nb_properties_per_streamline) + '\n' - txt += "to_world_space: " + repr(self.to_world_space) + '\n' - txt += "voxel_sizes: " + repr(self.voxel_sizes) + '\n' - - txt += "Extra fields: {\n" - for key in sorted(self.extra.keys()): - txt += " " + repr(key) + ": " + repr(self.extra[key]) + "\n" - - txt += " }\n" - return txt + "}" diff --git a/nibabel/streamlines/tests/test_header.py b/nibabel/streamlines/tests/test_header.py deleted file mode 100644 index 398195f615..0000000000 --- a/nibabel/streamlines/tests/test_header.py +++ /dev/null @@ -1,37 +0,0 @@ -import numpy as np - -from nose.tools import assert_equal, assert_true -from numpy.testing import assert_array_equal - -from nibabel.streamlines.header import TractogramHeader - - -def test_streamlines_header(): - header = TractogramHeader() - assert_true(header.nb_streamlines is None) - assert_true(header.nb_scalars_per_point is None) - assert_true(header.nb_properties_per_streamline is None) - assert_array_equal(header.voxel_sizes, (1, 1, 1)) - assert_array_equal(header.to_world_space, np.eye(4)) - assert_equal(header.extra, {}) - - # Modify simple attributes - header.nb_streamlines = 1 - header.nb_scalars_per_point = 2 - header.nb_properties_per_streamline = 3 - assert_equal(header.nb_streamlines, 1) - assert_equal(header.nb_scalars_per_point, 2) - assert_equal(header.nb_properties_per_streamline, 3) - - # Modifying voxel_sizes should be reflected in to_world_space - header.voxel_sizes = (2, 3, 4) - assert_array_equal(header.voxel_sizes, (2, 3, 4)) - assert_array_equal(np.diag(header.to_world_space), (2, 3, 4, 1)) - - # Modifying scaling of to_world_space should be reflected in voxel_sizes - header.to_world_space = np.diag([4, 3, 2, 1]) - assert_array_equal(header.voxel_sizes, (4, 3, 2)) - assert_array_equal(header.to_world_space, np.diag([4, 3, 2, 1])) - - # Test that we can run __repr__ without error. - repr(header) diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index 8bb1aa41ea..6b1524d9fd 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -1,6 +1,4 @@ -from abc import ABCMeta, abstractmethod, abstractproperty - -from .header import TractogramHeader +from abc import ABCMeta, abstractmethod class HeaderWarning(Warning): @@ -30,8 +28,6 @@ class TractogramFile(object): def __init__(self, tractogram, header=None): self._tractogram = tractogram self._header = {} if header is None else header - #self._header = TractogramHeader() if header is None else header - #self._header = TractogramHeader(header) @property def tractogram(self): From 1ef3f9878941f8caabbb83b57469468402455348 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 4 Dec 2015 01:08:01 -0500 Subject: [PATCH 056/135] Increased test coverage. 
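
With `TractogramHeader` gone, tractogram files simply carry their metadata as a plain mapping keyed by the `Field` constants that remain in `header.py`. A minimal sketch of what such a header looks like (the values are invented):

    import numpy as np
    from nibabel.streamlines.header import Field

    header = {Field.NB_STREAMLINES: 120,
              Field.VOXEL_SIZES: (2.0, 2.0, 2.0),
              Field.VOXEL_TO_RASMM: np.eye(4)}
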
--- nibabel/streamlines/__init__.py | 6 +- nibabel/streamlines/tests/test_streamlines.py | 7 + .../streamlines/tests/test_tractogram_file.py | 107 +++++++++++ nibabel/streamlines/tests/test_trk.py | 14 ++ nibabel/streamlines/tractogram_file.py | 8 +- nibabel/streamlines/trk.py | 173 +++++++++--------- 6 files changed, 220 insertions(+), 95 deletions(-) create mode 100644 nibabel/streamlines/tests/test_tractogram_file.py diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 213d64297e..5071e98dd7 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -90,7 +90,7 @@ def load(fileobj, lazy_load=False, ref=None): tractogram_file = detect_format(fileobj) if tractogram_file is None: - raise TypeError("Unknown format for 'fileobj': {}".format(fileobj)) + raise ValueError("Unknown format for 'fileobj': {}".format(fileobj)) return tractogram_file.load(fileobj, lazy_load=lazy_load) @@ -125,7 +125,7 @@ def save_tractogram(tractogram, filename, **kwargs): tractogram_file_class = detect_format(filename) if tractogram_file_class is None: - raise TypeError("Unknown tractogram file format: '{}'".format(filename)) + raise ValueError("Unknown tractogram file format: '{}'".format(filename)) tractogram_file = tractogram_file_class(tractogram, **kwargs) - tractogram_file.save(filename) + save(tractogram_file, filename) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index c21c688989..975ce79275 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -247,3 +247,10 @@ def test_save_complex_file(self): tfile = nib.streamlines.load(f, lazy_load=False) assert_tractogram_equal(tfile.tractogram, tractogram) + + def test_load_unknown_format(self): + assert_raises(ValueError, nib.streamlines.load, "") + + def test_save_unknown_format(self): + assert_raises(ValueError, + nib.streamlines.save_tractogram, Tractogram(), "") diff --git a/nibabel/streamlines/tests/test_tractogram_file.py b/nibabel/streamlines/tests/test_tractogram_file.py new file mode 100644 index 0000000000..e3466bb29a --- /dev/null +++ b/nibabel/streamlines/tests/test_tractogram_file.py @@ -0,0 +1,107 @@ +import os +import unittest +import tempfile +import numpy as np + +from os.path import join as pjoin + +import nibabel as nib +from nibabel.externals.six import BytesIO + +from nibabel.testing import clear_and_catch_warnings +from nibabel.testing import assert_arrays_equal, check_iteration +from nose.tools import assert_equal, assert_raises, assert_true, assert_false + +from .test_tractogram import assert_tractogram_equal +from ..tractogram_file import TractogramFile +from ..tractogram import Tractogram, LazyTractogram +from ..tractogram import UsageWarning +from .. 
import trk + + +def test_subclassing_tractogram_file(): + + # Missing 'save' method + class DummyTractogramFile(TractogramFile): + @classmethod + def get_magic_number(cls): + return False + + @classmethod + def support_data_per_point(cls): + return False + + @classmethod + def support_data_per_streamline(cls): + return False + + @classmethod + def is_correct_format(cls, fileobj): + return False + + @classmethod + def load(cls, fileobj, lazy_load=True): + return None + + assert_raises(TypeError, DummyTractogramFile, Tractogram()) + + # Missing 'load' method + class DummyTractogramFile(TractogramFile): + @classmethod + def get_magic_number(cls): + return False + + @classmethod + def support_data_per_point(cls): + return False + + @classmethod + def support_data_per_streamline(cls): + return False + + @classmethod + def is_correct_format(cls, fileobj): + return False + + def save(self, fileobj): + pass + + assert_raises(TypeError, DummyTractogramFile, Tractogram()) + + +def test_tractogram_file(): + assert_raises(NotImplementedError, TractogramFile.get_magic_number) + assert_raises(NotImplementedError, TractogramFile.is_correct_format, "") + assert_raises(NotImplementedError, TractogramFile.support_data_per_point) + assert_raises(NotImplementedError, TractogramFile.support_data_per_streamline) + assert_raises(NotImplementedError, TractogramFile.load, "") + + # Testing calling the 'save' method of `TractogramFile` object. + class DummyTractogramFile(TractogramFile): + @classmethod + def get_magic_number(cls): + return False + + @classmethod + def support_data_per_point(cls): + return False + + @classmethod + def support_data_per_streamline(cls): + return False + + @classmethod + def is_correct_format(cls, fileobj): + return False + + @classmethod + def load(cls, fileobj, lazy_load=True): + return None + + @classmethod + def save(self, fileobj): + pass + + assert_raises(NotImplementedError, + super(DummyTractogramFile, + DummyTractogramFile(Tractogram)).save, "") diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 1a9b33fa6b..5fb43fee3f 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -348,3 +348,17 @@ def test_write_scalars_and_properties_name_too_long(self): assert_raises(ValueError, trk.save, BytesIO()) else: trk.save(BytesIO()) + + def test_str(self): + trk = TrkFile.load(self.complex_trk_filename) + str(trk) # Simply test it's not failing when called. + + def test_read_buffer_size(self): + tmp = TrkFile.READ_BUFFER_SIZE + TrkFile.READ_BUFFER_SIZE = 1 + + for lazy_load in [False, True]: + trk = TrkFile.load(self.complex_trk_filename, lazy_load=lazy_load) + assert_tractogram_equal(trk.tractogram, self.complex_tractogram) + + TrkFile.READ_BUFFER_SIZE = tmp diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index 6b1524d9fd..253da833dd 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -50,22 +50,22 @@ def get_streamlines(self): def get_header(self): return self.header - @classmethod + @abstractclassmethod def get_magic_number(cls): ''' Returns streamlines file's magic number. ''' raise NotImplementedError() - @classmethod + @abstractclassmethod def support_data_per_point(cls): ''' Tells if this tractogram format supports saving data per point. 
''' raise NotImplementedError() - @classmethod + @abstractclassmethod def support_data_per_streamline(cls): ''' Tells if this tractogram format supports saving data per streamline. ''' raise NotImplementedError() - @classmethod + @abstractclassmethod def is_correct_format(cls, fileobj): ''' Checks if the file has the right streamlines file format. Parameters diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 9b043ff821..4851d0ab6d 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -379,87 +379,6 @@ def write(self, tractogram): self.file.write(self.header.tostring()) -def create_compactlist_from_generator(gen): - BUFFER_SIZE = 10000000 # About 128 Mb if item shape is 3. - - streamlines = CompactList() - scalars = CompactList() - properties = np.array([]) - - gen = iter(gen) - try: - first_element = next(gen) - gen = itertools.chain([first_element], gen) - except StopIteration: - return streamlines, scalars, properties - - # Allocated some buffer memory. - pts = np.asarray(first_element[0]) - scals = np.asarray(first_element[1]) - props = np.asarray(first_element[2]) - - scals_shape = scals.shape - props_shape = props.shape - - streamlines._data = np.empty((BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype) - scalars._data = np.empty((BUFFER_SIZE, scals.shape[1]), dtype=scals.dtype) - properties = np.empty((BUFFER_SIZE, props.shape[0]), dtype=props.dtype) - - offset = 0 - for i, (pts, scals, props) in enumerate(gen): - pts = np.asarray(pts) - scals = np.asarray(scals) - props = np.asarray(props) - - if scals.shape[1] != scals_shape[1]: - raise ValueError("Number of scalars differs from one" - " point or streamline to another") - - if props.shape != props_shape: - raise ValueError("Number of properties differs from one" - " streamline to another") - - end = offset + len(pts) - if end >= len(streamlines._data): - # Resize is needed (at least `len(pts)` items will be added). - streamlines._data.resize((len(streamlines._data) + len(pts)+BUFFER_SIZE, pts.shape[1])) - scalars._data.resize((len(scalars._data) + len(scals)+BUFFER_SIZE, scals.shape[1])) - - streamlines._offsets.append(offset) - streamlines._lengths.append(len(pts)) - streamlines._data[offset:offset+len(pts)] = pts - scalars._data[offset:offset+len(scals)] = scals - - offset += len(pts) - - if i >= len(properties): - properties.resize((len(properties) + BUFFER_SIZE, props.shape[0])) - - properties[i] = props - - # Clear unused memory. - streamlines._data.resize((offset, pts.shape[1])) - - if scals_shape[1] == 0: - # Because resizing an empty ndarray creates memory! - scalars._data = np.empty((offset, scals.shape[1])) - else: - scalars._data.resize((offset, scals.shape[1])) - - # Share offsets and lengths between streamlines and scalars. - scalars._offsets = streamlines._offsets - scalars._lengths = streamlines._lengths - - if props_shape[0] == 0: - # Because resizing an empty ndarray creates memory! - properties = np.empty((i+1, props.shape[0])) - else: - properties.resize((i+1, props.shape[0])) - - return streamlines, scalars, properties - - - class TrkFile(TractogramFile): ''' Convenience class to encapsulate TRK file format. @@ -477,8 +396,9 @@ class TrkFile(TractogramFile): # Contants MAGIC_NUMBER = b"TRACK" HEADER_SIZE = 1000 + READ_BUFFER_SIZE = 10000000 # About 128 Mb if only no scalars nor properties. 
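
Promoting the read buffer size to a class attribute makes the buffered TRK reader tunable without editing the module, which is how the new `test_read_buffer_size` test forces the resize code path. A sketch of that pattern (`bundle.trk` is a hypothetical file name):

    from nibabel.streamlines.trk import TrkFile

    old_size = TrkFile.READ_BUFFER_SIZE
    TrkFile.READ_BUFFER_SIZE = 1  # every chunk now overflows the buffer
    try:
        trk = TrkFile.load('bundle.trk')  # hypothetical file name
    finally:
        TrkFile.READ_BUFFER_SIZE = old_size
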
- def __init__(self, tractogram, header=None, ref=np.eye(4)): + def __init__(self, tractogram, header=None): """ Parameters ---------- @@ -488,9 +408,6 @@ def __init__(self, tractogram, header=None, ref=np.eye(4)): header : ``TractogramHeader`` file (optional) Metadata associated to this tractogram file. - ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) - Reference space where streamlines live in. - Notes ----- Streamlines of the tractogram are assumed to be in *RAS+* and *mm* space @@ -501,7 +418,6 @@ def __init__(self, tractogram, header=None, ref=np.eye(4)): header = dict(zip(header_rec.dtype.names, header_rec)) super(TrkFile, self).__init__(tractogram, header) - #self._affine = get_affine_from_reference(ref) @classmethod def get_magic_number(cls): @@ -541,6 +457,87 @@ def is_correct_format(cls, fileobj): return False + @classmethod + def _create_compactlist_from_generator(cls, gen): + """ Creates a CompactList object from a generator yielding tuples of + points, scalars and properties. """ + + streamlines = CompactList() + scalars = CompactList() + properties = np.array([]) + + gen = iter(gen) + try: + first_element = next(gen) + gen = itertools.chain([first_element], gen) + except StopIteration: + return streamlines, scalars, properties + + # Allocated some buffer memory. + pts = np.asarray(first_element[0]) + scals = np.asarray(first_element[1]) + props = np.asarray(first_element[2]) + + scals_shape = scals.shape + props_shape = props.shape + + streamlines._data = np.empty((cls.READ_BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype) + scalars._data = np.empty((cls.READ_BUFFER_SIZE, scals.shape[1]), dtype=scals.dtype) + properties = np.empty((cls.READ_BUFFER_SIZE, props.shape[0]), dtype=props.dtype) + + offset = 0 + for i, (pts, scals, props) in enumerate(gen): + pts = np.asarray(pts) + scals = np.asarray(scals) + props = np.asarray(props) + + if scals.shape[1] != scals_shape[1]: + raise ValueError("Number of scalars differs from one" + " point or streamline to another") + + if props.shape != props_shape: + raise ValueError("Number of properties differs from one" + " streamline to another") + + end = offset + len(pts) + if end >= len(streamlines._data): + # Resize is needed (at least `len(pts)` items will be added). + streamlines._data.resize((len(streamlines._data) + len(pts)+cls.READ_BUFFER_SIZE, pts.shape[1])) + scalars._data.resize((len(scalars._data) + len(scals)+cls.READ_BUFFER_SIZE, scals.shape[1])) + + streamlines._offsets.append(offset) + streamlines._lengths.append(len(pts)) + streamlines._data[offset:offset+len(pts)] = pts + scalars._data[offset:offset+len(scals)] = scals + + offset += len(pts) + + if i >= len(properties): + properties.resize((len(properties) + cls.READ_BUFFER_SIZE, props.shape[0])) + + properties[i] = props + + # Clear unused memory. + streamlines._data.resize((offset, pts.shape[1])) + + if scals_shape[1] == 0: + # Because resizing an empty ndarray creates memory! + scalars._data = np.empty((offset, scals.shape[1])) + else: + scalars._data.resize((offset, scals.shape[1])) + + # Share offsets and lengths between streamlines and scalars. + scalars._offsets = streamlines._offsets + scalars._lengths = streamlines._lengths + + if props_shape[0] == 0: + # Because resizing an empty ndarray creates memory! + properties = np.empty((i+1, props.shape[0])) + else: + properties.resize((i+1, props.shape[0])) + + return streamlines, scalars, properties + @classmethod def load(cls, fileobj, lazy_load=False): ''' Loads streamlines from a file-like object. 
@@ -654,7 +651,7 @@ def _read(): tractogram = LazyTractogram.create_from(_read) else: - streamlines, scalars, properties = create_compactlist_from_generator(trk_reader) + streamlines, scalars, properties = cls._create_compactlist_from_generator(trk_reader) tractogram = Tractogram(streamlines) for scalar_name, slice_ in data_per_point_slice.items(): @@ -678,7 +675,7 @@ def _read(): #if tractogram.header.nb_properties_per_streamline != trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: # raise HeaderError("'nb_properties_per_streamline' does not match.") - return cls(tractogram, header=trk_reader.header, ref=affine) + return cls(tractogram, header=trk_reader.header) def save(self, fileobj): ''' Saves tractogram to a file-like object using TRK format. From d0352ec00d96c26245696514e418d6cd9a522db4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 4 Dec 2015 01:17:25 -0500 Subject: [PATCH 057/135] Python3 fixes --- .../streamlines/tests/test_tractogram_file.py | 19 ++----------------- nibabel/streamlines/tractogram_file.py | 4 ++-- nibabel/streamlines/trk.py | 4 ++-- 3 files changed, 6 insertions(+), 21 deletions(-) diff --git a/nibabel/streamlines/tests/test_tractogram_file.py b/nibabel/streamlines/tests/test_tractogram_file.py index e3466bb29a..10bb481e64 100644 --- a/nibabel/streamlines/tests/test_tractogram_file.py +++ b/nibabel/streamlines/tests/test_tractogram_file.py @@ -1,22 +1,7 @@ -import os -import unittest -import tempfile -import numpy as np +from nose.tools import assert_raises -from os.path import join as pjoin - -import nibabel as nib -from nibabel.externals.six import BytesIO - -from nibabel.testing import clear_and_catch_warnings -from nibabel.testing import assert_arrays_equal, check_iteration -from nose.tools import assert_equal, assert_raises, assert_true, assert_false - -from .test_tractogram import assert_tractogram_equal +from ..tractogram import Tractogram from ..tractogram_file import TractogramFile -from ..tractogram import Tractogram, LazyTractogram -from ..tractogram import UsageWarning -from .. import trk def test_subclassing_tractogram_file(): diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index 253da833dd..45d09c1617 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -1,4 +1,5 @@ from abc import ABCMeta, abstractmethod +from nibabel.externals.six import with_metaclass class HeaderWarning(Warning): @@ -21,9 +22,8 @@ def __init__(self, callable): super(abstractclassmethod, self).__init__(callable) -class TractogramFile(object): +class TractogramFile(with_metaclass(ABCMeta)): ''' Convenience class to encapsulate tractogram file format. 
    '''
-    __metaclass__ = ABCMeta

     def __init__(self, tractogram, header=None):
         self._tractogram = tractogram
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index 4851d0ab6d..f0bcc79cc9 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -715,9 +715,9 @@ def __str__(self):
         info += "\nvoxel_sizes: {0}".format(hdr[Field.VOXEL_SIZES])
         info += "\norigin: {0}".format(hdr[Field.ORIGIN])
         info += "\nnb_scalars: {0}".format(hdr[Field.NB_SCALARS_PER_POINT])
-        info += "\nscalar_name:\n {0}".format("\n".join(hdr['scalar_name']))
+        info += "\nscalar_name:\n {0}".format("\n".join(map(asstr, hdr['scalar_name'])))
         info += "\nnb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE])
-        info += "\nproperty_name:\n {0}".format("\n".join(hdr['property_name']))
+        info += "\nproperty_name:\n {0}".format("\n".join(map(asstr, hdr['property_name'])))
         info += "\nvox_to_world: {0}".format(hdr[Field.VOXEL_TO_RASMM])
         info += "\nvoxel_order: {0}".format(hdr[Field.VOXEL_ORDER])
         info += "\nimage_orientation_patient: {0}".format(hdr['image_orientation_patient'])

From 44c531ff7a638fa132a0b597c7a92e10a1f11f67 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Mon, 11 Jan 2016 15:28:22 -0500
Subject: [PATCH 058/135] Increased test coverage.

---
 .../tests/data/complex_big_endian.trk        | Bin 0 -> 1296 bytes
 nibabel/streamlines/tests/test_tractogram.py | 107 ++++++++++++++++++
 nibabel/streamlines/tests/test_trk.py        |  72 +++++++++++-
 nibabel/streamlines/trk.py                   |  25 ++--
 4 files changed, 188 insertions(+), 16 deletions(-)
 create mode 100644 nibabel/streamlines/tests/data/complex_big_endian.trk

diff --git a/nibabel/streamlines/tests/data/complex_big_endian.trk b/nibabel/streamlines/tests/data/complex_big_endian.trk
new file mode 100644
index 0000000000000000000000000000000000000000..ad9d0313978ab660629224bca8e1b0bfb83fa8a9
GIT binary patch
literal 1296
zcmWFua&-1)U|LP3{z(}bLK3E1fW?HfOrE)21)MB
znX|}z2X-!72R@N%pcpH
[binary patch data truncated in the source; the remaining hunks of this patch are also missing]

From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Mon, 11 Jan 2016 16:39:51 -0500
Subject: [PATCH 059/135] Added streamlines tests folder to the NiBabel packages

---
 setup.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/setup.py b/setup.py
index ff2a161980..5e9bf51c29 100755
--- a/setup.py
+++ b/setup.py
@@ -91,6 +91,7 @@ def main(**extra_args):
           'nibabel.tests',
           'nibabel.benchmarks',
           'nibabel.streamlines',
+          'nibabel.streamlines.tests',
           # install nisext as its own package
           'nisext',
           'nisext.tests'],

From 688f66c3edab60cbeb6f2c3c35b56630388ff90e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Tue, 12 Jan 2016 21:44:58 -0500
Subject: [PATCH 060/135] Made the code numpy 1.5 compliant

---
 nibabel/streamlines/tests/test_trk.py | 16 ++++++----------
 nibabel/streamlines/trk.py            |  8 ++++----
 2 files changed, 10 insertions(+), 14 deletions(-)

diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py
index aed87302e9..5c412103ae 100644
--- a/nibabel/streamlines/tests/test_trk.py
+++ b/nibabel/streamlines/tests/test_trk.py
@@ -8,29 +8,25 @@
 from nibabel.testing import suppress_warnings, clear_and_catch_warnings
 from nibabel.testing import assert_arrays_equal, check_iteration
 from nose.tools import assert_equal, assert_raises, assert_true
+from numpy.testing import assert_array_equal

 from .test_tractogram import assert_tractogram_equal
 from ..tractogram import Tractogram, LazyTractogram
 from ..tractogram_file import DataError, HeaderError, HeaderWarning
 from ..
import trk as trk_module -from ..trk import TrkFile, header_2_dtype +from ..trk import TrkFile from ..header import Field DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') def assert_header_equal(h1, h2): - header1 = np.zeros(1, dtype=header_2_dtype) - header2 = np.zeros(1, dtype=header_2_dtype) + for k in h1.keys(): + assert_array_equal(h2[k], h1[k]) - for k, v in h1.items(): - header1[k] = v - - for k, v in h2.items(): - header2[k] = v - - assert_equal(header1, header2) + for k in h2.keys(): + assert_array_equal(h1[k], h2[k]) class TestTRK(unittest.TestCase): diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 5644ac87f3..2ece43bab1 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -208,9 +208,9 @@ def create_empty_header(cls): #Default values header[Field.MAGIC_NUMBER] = TrkFile.MAGIC_NUMBER - header[Field.VOXEL_SIZES] = (1, 1, 1) - header[Field.DIMENSIONS] = (1, 1, 1) - header[Field.VOXEL_TO_RASMM] = np.eye(4) + header[Field.VOXEL_SIZES] = np.array((1, 1, 1), dtype="f4") + header[Field.DIMENSIONS] = np.array((1, 1, 1), dtype="h") + header[Field.VOXEL_TO_RASMM] = np.eye(4, dtype="f4") header[Field.VOXEL_ORDER] = b"RAS" header['version'] = 2 header['hdr_size'] = TrkFile.HEADER_SIZE @@ -416,7 +416,7 @@ def __init__(self, tractogram, header=None): """ if header is None: header_rec = TrkWriter.create_empty_header() - header = dict(zip(header_rec.dtype.names, header_rec)) + header = dict(zip(header_rec.dtype.names, header_rec[0])) super(TrkFile, self).__init__(tractogram, header) From b16530c85956950bd26db3f493284b9de186cf06 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 12 Jan 2016 23:30:44 -0500 Subject: [PATCH 061/135] Make code python2.6 compliant --- nibabel/streamlines/compact_list.py | 10 +++++----- nibabel/streamlines/tests/test_compact_list.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index 8fb761f312..eade68309f 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -22,7 +22,7 @@ def __init__(self, iterable=None): memory is allocated. For an actual copy use the `.copy()` method. """ # Create new empty `CompactList` object. - self._data = None + self._data = np.array(0) self._offsets = [] self._lengths = [] @@ -54,13 +54,13 @@ def __init__(self, iterable=None): offset += len(e) # Clear unused memory. - if self._data is not None: + if self._data.ndim != 0: self._data.resize((offset,) + self.shape) @property def shape(self): """ Returns the matching shape of the elements in this compact list. """ - if self._data is None: + if self._data.ndim == 0: return None return self._data.shape[1:] @@ -79,7 +79,7 @@ def append(self, element): If you need to add multiple elements you should consider `CompactList.extend`. """ - if self._data is None: + if self._data.ndim == 0: self._data = np.asarray(element).copy() self._offsets.append(0) self._lengths.append(len(element)) @@ -103,7 +103,7 @@ def extend(self, elements): shape except for the first dimension. 
""" - if self._data is None: + if self._data.ndim == 0: elem = np.asarray(elements[0]) self._data = np.zeros((0, elem.shape[1]), dtype=elem.dtype) diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py index 188102ce85..794f11ce95 100644 --- a/nibabel/streamlines/tests/test_compact_list.py +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -26,7 +26,7 @@ def test_creating_empty_compactlist(self): assert_equal(len(clist), 0) assert_equal(len(clist._offsets), 0) assert_equal(len(clist._lengths), 0) - assert_true(clist._data is None) + assert_equal(clist._data.ndim, 0) assert_true(clist.shape is None) def test_creating_compactlist_from_list(self): @@ -49,7 +49,7 @@ def test_creating_compactlist_from_list(self): assert_equal(len(clist), 0) assert_equal(len(clist._offsets), 0) assert_equal(len(clist._lengths), 0) - assert_true(clist._data is None) + assert_equal(clist._data.ndim, 0) assert_true(clist.shape is None) # Force CompactList constructor to use buffering. @@ -87,7 +87,7 @@ def test_creating_compactlist_from_generator(self): assert_equal(len(clist), 0) assert_equal(len(clist._offsets), 0) assert_equal(len(clist._lengths), 0) - assert_true(clist._data is None) + assert_equal(clist._data.ndim, 0) assert_true(clist.shape is None) def test_creating_compactlist_from_compact_list(self): From dfb7dec5569dd7f6ba5b4a53cc632e2958c41771 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 12 Jan 2016 23:43:47 -0500 Subject: [PATCH 062/135] Use nibabel's InTemporaryDirectory instead of tempfile.NamedTemporaryFile --- nibabel/streamlines/tests/test_streamlines.py | 67 ++++++++++--------- 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 975ce79275..10eb415748 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -7,9 +7,9 @@ import nibabel as nib from nibabel.externals.six import BytesIO +from nibabel.tmpdirs import InTemporaryDirectory from nibabel.testing import clear_and_catch_warnings -from nibabel.testing import assert_arrays_equal, check_iteration from nose.tools import assert_equal, assert_raises, assert_true, assert_false from .test_tractogram import assert_tractogram_equal @@ -206,18 +206,20 @@ def test_load_complex_file(self): def test_save_empty_file(self): tractogram = Tractogram() for ext, cls in nib.streamlines.FORMATS.items(): - with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: - nib.streamlines.save_tractogram(tractogram, f.name) - tfile = nib.streamlines.load(f, lazy_load=False) - assert_tractogram_equal(tfile.tractogram, tractogram) + with InTemporaryDirectory(): + with open('streamlines' + ext, 'w+b') as f: + nib.streamlines.save_tractogram(tractogram, f.name) + tfile = nib.streamlines.load(f, lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) def test_save_simple_file(self): tractogram = Tractogram(self.streamlines) for ext, cls in nib.streamlines.FORMATS.items(): - with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f: - nib.streamlines.save_tractogram(tractogram, f.name) - tfile = nib.streamlines.load(f, lazy_load=False) - assert_tractogram_equal(tfile.tractogram, tractogram) + with InTemporaryDirectory(): + with open('streamlines' + ext, 'w+b') as f: + nib.streamlines.save_tractogram(tractogram, f.name) + tfile = nib.streamlines.load(f, lazy_load=False) + 
assert_tractogram_equal(tfile.tractogram, tractogram)

     def test_save_complex_file(self):
         complex_tractogram = Tractogram(self.streamlines,
@@ -225,28 +227,31 @@ def test_save_complex_file(self):
                                         self.data_per_point)

         for ext, cls in nib.streamlines.FORMATS.items():
-            with tempfile.NamedTemporaryFile(mode="w+b", suffix=ext) as f:
-                with clear_and_catch_warnings(record=True, modules=[trk]) as w:
-                    nib.streamlines.save_tractogram(complex_tractogram, f.name)
-
-                    # If streamlines format does not support saving data per
-                    # point or data per streamline, a warning message should
-                    # be issued.
-                    if not (cls.support_data_per_point()
-                            and cls.support_data_per_streamline()):
-                        assert_equal(len(w), 1)
-                        assert_true(issubclass(w[0].category, UsageWarning))
-
-                    tractogram = Tractogram(self.streamlines)
-
-                    if cls.support_data_per_point():
-                        tractogram.data_per_point = self.data_per_point
-
-                    if cls.support_data_per_streamline():
-                        tractogram.data_per_streamline = self.data_per_streamline
-
-                    tfile = nib.streamlines.load(f, lazy_load=False)
-                    assert_tractogram_equal(tfile.tractogram, tractogram)
+            with InTemporaryDirectory():
+                with open('streamlines' + ext, 'w+b') as f:
+                    with clear_and_catch_warnings(record=True,
+                                                  modules=[trk]) as w:
+                        nib.streamlines.save_tractogram(complex_tractogram,
+                                                        f.name)
+
+                        # If streamlines format does not support saving data per
+                        # point or data per streamline, a warning message should
+                        # be issued.
+                        if not (cls.support_data_per_point()
+                                and cls.support_data_per_streamline()):
+                            assert_equal(len(w), 1)
+                            assert_true(issubclass(w[0].category, UsageWarning))
+
+                        tractogram = Tractogram(self.streamlines)
+
+                        if cls.support_data_per_point():
+                            tractogram.data_per_point = self.data_per_point
+
+                        if cls.support_data_per_streamline():
+                            tractogram.data_per_streamline = self.data_per_streamline
+
+                        tfile = nib.streamlines.load(f, lazy_load=False)
+                        assert_tractogram_equal(tfile.tractogram, tractogram)

     def test_load_unknown_format(self):
         assert_raises(ValueError, nib.streamlines.load, "")

From be2d6dfab7134accc84072035c9f7b923f2a2791 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Wed, 13 Jan 2016 21:10:16 -0500
Subject: [PATCH 063/135] Reduced read buffer from 128Mb to 4Mb

---
 nibabel/streamlines/compact_list.py | 2 +-
 nibabel/streamlines/trk.py          | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py
index eade68309f..ef3faba8c4 100644
--- a/nibabel/streamlines/compact_list.py
+++ b/nibabel/streamlines/compact_list.py
@@ -6,7 +6,7 @@ class CompactList(object):
     the first dimension. """

-    BUFFER_SIZE = 10000000  # About 128 Mb if item shape is 3.
+    BUFFER_SIZE = 349525  # About 4 Mb if item shape is 3 (e.g. 3D points).

     def __init__(self, iterable=None):
         """
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index 2ece43bab1..6c3b0cc754 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -397,7 +397,7 @@ class TrkFile(TractogramFile):
     # Constants
     MAGIC_NUMBER = b"TRACK"
     HEADER_SIZE = 1000
-    READ_BUFFER_SIZE = 10000000  # About 128 Mb if there are no scalars or properties.
+    READ_BUFFER_SIZE = 349525  # About 4 Mb if there are no scalars or properties.
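
For reference, the "About 4 Mb" figure in the comment above checks out: the
buffer holds 349525 rows of three 4-byte floats. A quick sanity check (not
part of the patch):

    import numpy as np

    rows = 349525                            # READ_BUFFER_SIZE above
    row_bytes = 3 * np.dtype("f4").itemsize  # one (x, y, z) float32 point
    print(rows * row_bytes / 2.**20)         # -> ~4.0 MiB
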
     def __init__(self, tractogram, header=None):
         """

From 62d0681b2202a4fae83d8ea579feb639b4a441df Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Wed, 13 Jan 2016 22:16:38 -0500
Subject: [PATCH 064/135] Increased read buffer from 4Mb to 8Mb

---
 nibabel/streamlines/compact_list.py | 2 +-
 nibabel/streamlines/trk.py          | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py
index ef3faba8c4..457dc0da56 100644
--- a/nibabel/streamlines/compact_list.py
+++ b/nibabel/streamlines/compact_list.py
@@ -6,7 +6,7 @@ class CompactList(object):
     the first dimension. """

-    BUFFER_SIZE = 349525  # About 4 Mb if item shape is 3 (e.g. 3D points).
+    BUFFER_SIZE = 87382*8  # About 8 Mb if item shape is 3 (e.g. 3D points).

     def __init__(self, iterable=None):
         """
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index 6c3b0cc754..49b07fa921 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -397,7 +397,7 @@ class TrkFile(TractogramFile):
     # Constants
     MAGIC_NUMBER = b"TRACK"
     HEADER_SIZE = 1000
-    READ_BUFFER_SIZE = 349525  # About 4 Mb if there are no scalars or properties.
+    READ_BUFFER_SIZE = 87382*8  # About 8 Mb if there are no scalars or properties.

     def __init__(self, tractogram, header=None):
         """

From 661870041a857a1143645c65b123fd9ed4e8f05c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Thu, 14 Jan 2016 08:10:26 -0500
Subject: [PATCH 065/135] Decreased read buffer from 8Mb back to 4Mb

---
 nibabel/streamlines/compact_list.py | 2 +-
 nibabel/streamlines/trk.py          | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py
index 457dc0da56..93e4d08fbf 100644
--- a/nibabel/streamlines/compact_list.py
+++ b/nibabel/streamlines/compact_list.py
@@ -6,7 +6,7 @@ class CompactList(object):
     the first dimension. """

-    BUFFER_SIZE = 87382*8  # About 8 Mb if item shape is 3 (e.g. 3D points).
+    BUFFER_SIZE = 87382*4  # About 4 Mb if item shape is 3 (e.g. 3D points).

     def __init__(self, iterable=None):
         """
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index 49b07fa921..464830d125 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -397,7 +397,7 @@ class TrkFile(TractogramFile):
     # Constants
     MAGIC_NUMBER = b"TRACK"
     HEADER_SIZE = 1000
-    READ_BUFFER_SIZE = 87382*8  # About 8 Mb if there are no scalars or properties.
+    READ_BUFFER_SIZE = 87382*4  # About 4 Mb if there are no scalars or properties.

     def __init__(self, tractogram, header=None):
         """

From 565f064c038e41b2262a68f29d5df6818655879e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Thu, 28 Jan 2016 16:05:58 -0500
Subject: [PATCH 066/135] Addressed comments of @samuelstjean and @MrBago.
 Added a lazy option for apply_affine.
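
The new `lazy` keyword gives two behaviours: `apply_affine(affine)` transforms
the points in place and composes `affine_to_rasmm` with the inverse transform,
while `apply_affine(affine, lazy=True)` returns a `LazyTractogram` that defers
the transformation. A sketch of the intended call pattern, based on the API
this patch introduces (not verified against this exact revision):

    import numpy as np
    from nibabel.streamlines.tractogram import Tractogram

    tractogram = Tractogram([np.arange(6, dtype="f4").reshape((2, 3))])
    scale = np.diag([1., 2., 3., 1.])

    lazy = tractogram.apply_affine(scale, lazy=True)  # deferred LazyTractogram
    tractogram.apply_affine(scale)                    # in-place, returns self
    print(tractogram.affine_to_rasmm)                 # inv(scale): back to RAS+ mm
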
--- nibabel/streamlines/__init__.py | 6 +- nibabel/streamlines/tests/test_tractogram.py | 51 ++++- nibabel/streamlines/tractogram.py | 185 ++++++++++++++----- nibabel/streamlines/tractogram_file.py | 5 + 4 files changed, 194 insertions(+), 53 deletions(-) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 5071e98dd7..42c0619531 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -5,9 +5,9 @@ from .compact_list import CompactList from .tractogram import Tractogram, LazyTractogram -from nibabel.streamlines.trk import TrkFile -#from nibabel.streamlines.tck import TckFile -#from nibabel.streamlines.vtk import VtkFile +from .trk import TrkFile +#from .tck import TckFile +#from .vtk import VtkFile # List of all supported formats FORMATS = {".trk": TrkFile, diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 62c793558a..68c8f61438 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -384,6 +384,49 @@ def test_creating_invalid_tractogram(self): assert_raises(ValueError, Tractogram, self.streamlines, data_per_streamline={'properties': properties}) + def test_tractogram_apply_affine(self): + # Create a tractogram with streamlines and other data. + tractogram = Tractogram( + self.streamlines, + data_per_streamline={'mean_curvature': self.mean_curvature, + 'mean_color': self.mean_color}, + data_per_point={'colors': self.colors}) + + affine = np.eye(4) + scaling = np.array((1, 2, 3), dtype=float) + affine[range(3), range(3)] = scaling + + # Apply the affine to the streamline in a lazy manner. + transformed_tractogram = tractogram.apply_affine(affine, lazy=True) + assert_true(type(transformed_tractogram) is LazyTractogram) + assert_true(check_iteration(transformed_tractogram)) + assert_equal(len(transformed_tractogram), len(self.streamlines)) + for s1, s2 in zip(transformed_tractogram.streamlines, + self.streamlines): + assert_array_almost_equal(s1, s2*scaling) + + for s1, s2 in zip(transformed_tractogram.streamlines, + tractogram.streamlines): + assert_array_almost_equal(s1, s2*scaling) + + assert_array_equal(transformed_tractogram.affine_to_rasmm, + np.dot(np.eye(4), np.linalg.inv(affine))) + + # Apply the affine to the streamlines in-place. + transformed_tractogram = tractogram.apply_affine(affine) + assert_true(transformed_tractogram is tractogram) + assert_true(check_iteration(transformed_tractogram)) + assert_equal(len(transformed_tractogram), len(self.streamlines)) + for s1, s2 in zip(transformed_tractogram.streamlines, + self.streamlines): + assert_array_almost_equal(s1, s2*scaling) + + # Apply affine again and check the affine_to_rasmm property. 
+ transformed_tractogram = tractogram.apply_affine(affine) + assert_array_equal(transformed_tractogram.affine_to_rasmm, + np.dot(np.eye(4), np.dot(np.linalg.inv(affine), + np.linalg.inv(affine)))) + class TestLazyTractogram(unittest.TestCase): @@ -500,7 +543,7 @@ def test_lazy_tractogram_getitem(self): tractogram = LazyTractogram(streamlines, data_per_streamline=data_per_streamline, data_per_point=data_per_point) - assert_raises(AttributeError, tractogram.__getitem__, 0) + assert_raises(NotImplementedError, tractogram.__getitem__, 0) def test_lazy_tractogram_len(self): streamlines = lambda: (x for x in self.streamlines) @@ -570,6 +613,12 @@ def test_lazy_tractogram_apply_affine(self): for s1, s2 in zip(tractogram.streamlines, self.streamlines): assert_array_almost_equal(s1, s2*scaling) + # Apply affine again and check the affine_to_rasmm property. + transformed_tractogram = tractogram.apply_affine(affine) + assert_array_equal(transformed_tractogram.affine_to_rasmm, + np.dot(np.eye(4), np.dot(np.linalg.inv(affine), + np.linalg.inv(affine)))) + def test_lazy_tractogram_copy(self): # Create tractogram with streamlines and other data streamlines = lambda: (x for x in self.streamlines) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index e422c10ca5..8568fa4a21 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -44,22 +44,11 @@ class Tractogram(object): Tractogram objects have three main properties: ``streamlines`` - Parameters + Attributes ---------- - streamlines : list of ndarray of shape (Nt, 3) - Sequence of T streamlines. One streamline is an ndarray of shape - (Nt, 3) where Nt is the number of points of streamline t. - - data_per_streamline : dictionary of list of ndarray of shape (P,) - Sequence of T ndarrays of shape (P,) where T is the number of - streamlines defined by ``streamlines``, P is the number of properties - associated to each streamline. - - data_per_point : dictionary of list of ndarray of shape (Nt, M) - Sequence of T ndarrays of shape (Nt, M) where T is the number of - streamlines defined by ``streamlines``, Nt is the number of points - for a particular streamline t and M is the number of scalars - associated to each point (excluding the three coordinates). + affine_to_rasmm : 2D array (4,4) + Affine that brings the streamlines back to *RAS+* and *mm* space + where coordinate (0,0,0) refers to the center of the voxel. """ class DataDict(collections.MutableMapping): @@ -132,10 +121,30 @@ def __setitem__(self, key, value): def __init__(self, streamlines=None, data_per_streamline=None, data_per_point=None): + """ + Parameters + ---------- + streamlines : list of ndarray of shape (Nt, 3) (optional) + Sequence of T streamlines. One streamline is an ndarray of + shape (Nt, 3) where Nt is the number of points of streamline t. + + data_per_streamline : dict of list of ndarray of shape (P,) (optional) + Sequence of T ndarrays of shape (P,) where T is the number of + streamlines defined by ``streamlines``, P is the number of + properties associated to each streamline. + + data_per_point : dict of list of ndarray of shape (Nt, M) (optional) + Sequence of T ndarrays of shape (Nt, M) where T is the number + of streamlines defined by ``streamlines``, Nt is the number of + points for a particular streamline t and M is the number of + scalars associated to each point (excluding the three + coordinates). 
+ """ self.streamlines = streamlines self.data_per_streamline = data_per_streamline self.data_per_point = data_per_point + self._affine_to_rasmm = np.eye(4) @property def streamlines(self): @@ -162,6 +171,11 @@ def data_per_point(self): def data_per_point(self, value): self._data_per_point = Tractogram.DataPerPointDict(self, value) + @property + def affine_to_rasmm(self): + # Return a copy. User should use self.apply_affine` to modify it. + return self._affine_to_rasmm.copy() + def __iter__(self): for i in range(len(self.streamlines)): yield self[i] @@ -198,26 +212,48 @@ def copy(self): tractogram = Tractogram(self.streamlines.copy(), data_per_streamline, data_per_point) + + tractogram._affine_to_rasmm = self.affine_to_rasmm return tractogram - def apply_affine(self, affine): + def apply_affine(self, affine, lazy=False): """ Applies an affine transformation on the points of each streamline. - This is performed in-place. + If `lazy` is not specified, this is performed *in-place*. Parameters ---------- affine : 2D array (4,4) - Transformation that will be applied on each streamline. + Transformation that will be applied to every streamline. + + Returns + ------- + tractogram : ``Tractogram`` or ``LazyTractogram`` object + Tractogram where the streamlines have been transformed according + to the given affine transformation. If the `lazy` option is true, + it returns a ``LazyTractogram`` object, otherwise it returns a + reference to this ``Tractogram`` object with updated streamlines. + """ + if lazy: + lazy_tractogram = LazyTractogram.from_tractogram(self) + lazy_tractogram.apply_affine(affine) + return lazy_tractogram + if len(self.streamlines) == 0: - return + return self BUFFER_SIZE = 10000000 # About 128 Mb since pts shape is 3. for i in range(0, len(self.streamlines._data), BUFFER_SIZE): pts = self.streamlines._data[i:i+BUFFER_SIZE] self.streamlines._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts) + # Update the affine that brings back the streamlines to RASmm. + self._affine_to_rasmm = np.dot(self._affine_to_rasmm, + np.linalg.inv(affine)) + + return self + class LazyTractogram(Tractogram): ''' Class containing information about streamlines. @@ -227,28 +263,6 @@ class LazyTractogram(Tractogram): produce tuple of ``streamlines``, ``scalars`` and ``properties`` for each streamline. - Parameters - ---------- - streamlines_func : coroutine ouputting (Nt,3) array-like (optional) - Function yielding streamlines. One streamline is - an ndarray of shape (Nt,3) where Nt is the number of points of - streamline t. - - scalars_func : coroutine ouputting (Nt,M) array-like (optional) - Function yielding scalars for a particular streamline t. The scalars - are represented as an ndarray of shape (Nt,M) where Nt is the number - of points of that streamline t and M is the number of scalars - associated to each point (excluding the three coordinates). - - properties_func : coroutine ouputting (P,) array-like (optional) - Function yielding properties for a particular streamline t. The - properties are represented as an ndarray of shape (P,) where P is - the number of properties associated to each streamline. - - getitem_func : function `idx -> 3-tuples` (optional) - Function returning a subset of the tractogram given an index or a - slice (i.e. the __getitem__ function to use). 
-
     Notes
     -----
     If provided, ``scalars`` and ``properties`` must yield the same number of
     values as ``streamlines``.
@@ -285,26 +299,84 @@ def __iter__(self):
         def __len__(self):
             return len(self.store)

-    def __init__(self, streamlines=None, data_per_streamline=None,
+    def __init__(self, streamlines=None,
+                 data_per_streamline=None,
                  data_per_point=None):
-        super(LazyTractogram, self).__init__(streamlines, data_per_streamline,
+        """
+        Parameters
+        ----------
+        streamlines : coroutine yielding ndarrays of shape (Nt,3) (optional)
+            Function yielding streamlines. One streamline is an ndarray of
+            shape (Nt,3) where Nt is the number of points of streamline t.
+
+        data_per_streamline : dict of coroutines yielding ndarrays of shape (P,) (optional)
+            Function yielding properties for a particular streamline t. The
+            properties are represented as an ndarray of shape (P,) where P is
+            the number of properties associated to each streamline.
+
+        data_per_point : dict of coroutines yielding ndarrays of shape (Nt,M) (optional)
+            Function yielding scalars for a particular streamline t. The
+            scalars are represented as an ndarray of shape (Nt,M) where Nt
+            is the number of points of that streamline t and M is the number
+            of scalars associated to each point (excluding the three
+            coordinates).
+
+        """
+        super(LazyTractogram, self).__init__(streamlines,
+                                             data_per_streamline,
                                              data_per_point)
         self._nb_streamlines = None
         self._data = None
         self._affine_to_apply = np.eye(4)

+    @classmethod
+    def from_tractogram(cls, tractogram):
+        ''' Creates a ``LazyTractogram`` object from a ``Tractogram`` object.
+
+        Parameters
+        ----------
+        tractogram : ``Tractogram`` object
+            Tractogram from which to create a ``LazyTractogram`` object.
+
+        Returns
+        -------
+        lazy_tractogram : ``LazyTractogram`` object
+            New lazy tractogram.
+
+        '''
+        data_per_streamline = {}
+        for key, value in tractogram.data_per_streamline.items():
+            data_per_streamline[key] = lambda value=value: value  # default arg binds `value` now (avoids the late-binding bug)
+
+        data_per_point = {}
+        for key, value in tractogram.data_per_point.items():
+            data_per_point[key] = lambda value=value: value  # idem: bind at definition time
+
+        lazy_tractogram = cls(lambda: tractogram.streamlines.copy(),
+                              data_per_streamline,
+                              data_per_point)
+
+        lazy_tractogram._nb_streamlines = len(tractogram)
+        lazy_tractogram._affine_to_rasmm = tractogram.affine_to_rasmm
+        return lazy_tractogram
+
     @classmethod
     def create_from(cls, data_func):
-        ''' Creates a `LazyTractogram` from a coroutine yielding
-        `TractogramItem` objects.
+        ''' Creates a ``LazyTractogram`` from a coroutine yielding
+        ``TractogramItem`` objects.

         Parameters
         ----------
-        data_func : coroutine yielding `TractogramItem` objects
+        data_func : coroutine yielding ``TractogramItem`` objects
             A function that whenever it is called starts yielding
-            `TractogramItem` objects that should be part of this
+            ``TractogramItem`` objects that should be part of this
             LazyTractogram.

+        Returns
+        -------
+        lazy_tractogram : ``LazyTractogram`` object
+            New lazy tractogram.
+
         '''
         if not callable(data_func):
             raise TypeError("`data_func` must be a coroutine.")
@@ -411,7 +483,7 @@ def _gen_data():
         return _gen_data()

     def __getitem__(self, idx):
-        raise AttributeError('`LazyTractogram` does not support indexing.')
+        raise NotImplementedError('`LazyTractogram` does not support indexing.')

     def __iter__(self):
         i = 0
@@ -446,11 +518,26 @@ def copy(self):
         return tractogram

     def apply_affine(self, affine):
-        """ Applies an affine transformation on the streamlines.
+        """ Applies an affine transformation to the streamlines.
+ + The transformation will be applied just before returning the + streamlines. Parameters ---------- affine : 2D array (4,4) Transformation that will be applied on each streamline. + + Returns + ------- + lazy_tractogram : ``LazyTractogram`` object + Reference to this instance of ``LazyTractogram``. + """ + # Update the affine that will be applied when returning streamlines. self._affine_to_apply = np.dot(affine, self._affine_to_apply) + + # Update the affine that brings back the streamlines to RASmm. + self._affine_to_rasmm = np.dot(self._affine_to_rasmm, + np.linalg.inv(affine)) + return self diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index 45d09c1617..d6d7fb693e 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -68,12 +68,14 @@ def support_data_per_streamline(cls): @abstractclassmethod def is_correct_format(cls, fileobj): ''' Checks if the file has the right streamlines file format. + Parameters ---------- fileobj : string or file-like object If string, a filename; otherwise an open file-like object pointing to a streamlines file (and ready to read from the beginning of the header). + Returns ------- is_correct_format : boolean @@ -84,6 +86,7 @@ def is_correct_format(cls, fileobj): @abstractclassmethod def load(cls, fileobj, lazy_load=True): ''' Loads streamlines from a file-like object. + Parameters ---------- fileobj : string or file-like object @@ -93,6 +96,7 @@ def load(cls, fileobj, lazy_load=True): lazy_load : boolean (optional) Load streamlines in a lazy manner i.e. they will not be kept in memory. For postprocessing speed, turn off this option. + Returns ------- tractogram_file : ``TractogramFile`` object @@ -104,6 +108,7 @@ def load(cls, fileobj, lazy_load=True): @abstractmethod def save(self, fileobj): ''' Saves streamlines to a file-like object. + Parameters ---------- fileobj : string or file-like object From 6b8c115199383858670e3ef1baec3a7539c19931 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 29 Jan 2016 07:51:59 -0500 Subject: [PATCH 067/135] Added property 'affine' to TractogramFile. 
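
The `affine` property added by this patch simply exposes the header's
vox-to-RAS matrix. A minimal sketch of how it reads (assuming a `TrkFile`
built with the default header, in which that matrix is the identity):

    import numpy as np
    from nibabel.streamlines.trk import TrkFile
    from nibabel.streamlines.tractogram import Tractogram

    trk = TrkFile(Tractogram())                   # default header
    print(np.array_equal(trk.affine, np.eye(4)))  # True: vox -> rasmm identity
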
--- nibabel/streamlines/tests/test_trk.py | 1 + nibabel/streamlines/tractogram_file.py | 10 ++++++++++ 2 files changed, 11 insertions(+) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 5c412103ae..80af2c02c0 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -147,6 +147,7 @@ def test_tractogram_file_properties(self): assert_equal(trk.get_streamlines(), trk.streamlines) assert_equal(trk.get_tractogram(), trk.tractogram) assert_equal(trk.get_header(), trk.header) + assert_array_equal(trk.get_affine(), trk.header[Field.VOXEL_TO_RASMM]) def test_write_empty_file(self): tractogram = Tractogram() diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index d6d7fb693e..9eeb2a303c 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -1,6 +1,8 @@ from abc import ABCMeta, abstractmethod from nibabel.externals.six import with_metaclass +from .header import Field + class HeaderWarning(Warning): pass @@ -41,6 +43,10 @@ def streamlines(self): def header(self): return self._header + @property + def affine(self): + return self.header.get(Field.VOXEL_TO_RASMM) + def get_tractogram(self): return self.tractogram @@ -50,6 +56,10 @@ def get_streamlines(self): def get_header(self): return self.header + def get_affine(self): + """ Returns vox -> rasmm affine. """ + return self.affine + @abstractclassmethod def get_magic_number(cls): ''' Returns streamlines file's magic number. ''' From b6acfccb1cc99c61a4388b787a0b6f1be9b0d06e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 29 Jan 2016 22:42:18 -0500 Subject: [PATCH 068/135] Added check for the vox_to_ras affine. --- nibabel/streamlines/tests/test_trk.py | 18 ++++ nibabel/streamlines/trk.py | 132 ++++++++++++++++---------- 2 files changed, 100 insertions(+), 50 deletions(-) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 80af2c02c0..b88c1d2634 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -111,6 +111,24 @@ def test_load_file_with_wrong_information(self): trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) assert_tractogram_equal(trk.tractogram, self.simple_tractogram) + # Simulate a TRK where `vox_to_ras` is not recorded (i.e. all zeros). + vox_to_ras = np.zeros((4, 4), dtype=np.float32).tostring() + new_trk_file = trk_file[:440] + vox_to_ras + trk_file[440+64:] + with clear_and_catch_warnings(record=True, modules=[trk_module]) as w: + trk = TrkFile.load(BytesIO(new_trk_file)) + assert_equal(len(w), 1) + assert_true(issubclass(w[0].category, HeaderWarning)) + assert_true("identity" in str(w[0].message)) + assert_array_equal(trk.affine, np.eye(4)) + + # Simulate a TRK where `vox_to_ras` is invalid. + vox_to_ras = np.zeros((4, 4), dtype=np.float32) + vox_to_ras[3, 3] = 1 + vox_to_ras = vox_to_ras.tostring() + new_trk_file = trk_file[:440] + vox_to_ras + trk_file[440+64:] + with clear_and_catch_warnings(record=True, modules=[trk_module]) as w: + assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) + # Simulate a TRK file where `voxel_order` was not provided. 
        voxel_order = np.zeros(1, dtype="|S3").tostring()
        new_trk_file = trk_file[:948] + voxel_order + trk_file[948+3:]
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index 464830d125..b96ce08c6c 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -1,6 +1,6 @@
 from __future__ import division

-# Documentation available here:
+# Definition of trackvis header structure:
 # http://www.trackvis.org/docs/?subsect=fileformat

 import os
@@ -15,6 +15,7 @@
 from nibabel.openers import Opener
 from nibabel.py3k import asbytes, asstr
 from nibabel.volumeutils import (native_code, swapped_code)
+from nibabel.orientations import (aff2axcodes, axcodes2ornt)

 from .compact_list import CompactList
 from .tractogram_file import TractogramFile
@@ -26,8 +27,6 @@
 MAX_NB_NAMED_SCALARS_PER_POINT = 10
 MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE = 10

-# Definition of trackvis header structure.
-# See http://www.trackvis.org/docs/?subsect=fileformat
 # See http://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html
 header_1_dtd = [(Field.MAGIC_NUMBER, 'S6'),
                 (Field.DIMENSIONS, 'h', 3),
@@ -65,7 +64,7 @@
                 ('scalar_name', 'S20', MAX_NB_NAMED_SCALARS_PER_POINT),
                 (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'),
                 ('property_name', 'S20', MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE),
-                (Field.VOXEL_TO_RASMM, 'f4', (4, 4)),  # new field for version 2
+                (Field.VOXEL_TO_RASMM, 'f4', (4, 4)),  # New in version 2.
                 ('reserved', 'S444'),
                 (Field.VOXEL_ORDER, 'S4'),
                 ('pad2', 'S4'),
@@ -121,12 +120,15 @@ def __init__(self, fileobj):
                 self.endianness = swapped_code

                 # Swap byte order
-                header_rec = header_rec.newbyteorder()#np.array(header_rec, dtype=header_rec.newbyteorder().dtype)
+                header_rec = header_rec.newbyteorder()
                 if header_rec['hdr_size'] != TrkFile.HEADER_SIZE:
-                    raise HeaderError('Invalid hdr_size: {0} instead of {1}'.format(header_rec['hdr_size'], TrkFile.HEADER_SIZE))
+                    msg = "Invalid hdr_size: {0} instead of {1}"
+                    raise HeaderError(msg.format(header_rec['hdr_size'],
+                                                 TrkFile.HEADER_SIZE))

         if header_rec['version'] == 1:
-            header_rec = np.fromstring(string=header_str, dtype=header_1_dtype)
+            header_rec = np.fromstring(string=header_str,
+                                       dtype=header_1_dtype)
         elif header_rec['version'] == 2:
             pass  # Nothing more to do
         else:
@@ -135,12 +137,28 @@ def __init__(self, fileobj):
         # Convert the first record of `header_rec` into a dictionary
         self.header = dict(zip(header_rec.dtype.names, header_rec[0]))

+        # If vox_to_ras[3][3] is 0, it means the matrix is not recorded.
+        if self.header[Field.VOXEL_TO_RASMM][3][3] == 0:
+            self.header[Field.VOXEL_TO_RASMM] = np.eye(4, dtype=np.float32)
+            warnings.warn(("Field 'vox_to_ras' in the TRK's header was"
+                           " not recorded. Will continue assuming it's"
+                           " the identity."), HeaderWarning)
+
+        # Check that the 'vox_to_ras' affine is valid, i.e. it should be
+        # possible to determine the axis directions from it.
+        axcodes = aff2axcodes(self.header[Field.VOXEL_TO_RASMM])
+        if None in axcodes:
+            msg = ("The 'vox_to_ras' affine is invalid! Could not"
+                   " determine the axis directions from it.\n{0}"
+                   ).format(self.header[Field.VOXEL_TO_RASMM])
+            raise HeaderError(msg)
+
         # By default, the voxel order is LPS.
        # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates
         if self.header[Field.VOXEL_ORDER] == b"":
-            warnings.warn(("Voxel order is not specified, will assume"
-                           " 'LPS' since it is Trackvis software's"
-                           " default."), HeaderWarning)
+            msg = ("Voxel order is not specified, will assume 'LPS' since"
+                   " it is Trackvis software's default.")
+            warnings.warn(msg, HeaderWarning)
             self.header[Field.VOXEL_ORDER] = b"LPS"

         # Keep the file position where the data begin.
@@ -160,7 +178,8 @@ def __iter__(self):
             # Set the file position at the beginning of the data.
             f.seek(self.offset_data, os.SEEK_SET)

-            # If 'count' field is 0, i.e. not provided, we have to loop until the EOF.
+            # If 'count' field is 0, i.e. not provided, we have to loop
+            # until the EOF.
             nb_streamlines = self.header[Field.NB_STREAMLINES]
             if nb_streamlines == 0:
                 nb_streamlines = np.inf
@@ -206,7 +225,7 @@ def create_empty_header(cls):
         ''' Return an empty compliant TRK header. '''
         header = np.zeros(1, dtype=header_2_dtype)

-        #Default values
+        # Default values
         header[Field.MAGIC_NUMBER] = TrkFile.MAGIC_NUMBER
@@ -260,47 +279,65 @@ def write(self, tractogram):
             self.file.write(self.header.tostring())
             return

-        # Update the 'property_name' field using 'data_per_streamline' of the tractogram.
+        # Update the 'property_name' field using 'data_per_streamline' of the
+        # tractogram.
         data_for_streamline = first_item.data_for_streamline
         if len(data_for_streamline) > MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE:
-            raise ValueError("Can only store {0} named data_per_streamline (properties).".format(MAX_NB_NAMED_SCALARS_PER_POINT))
+            raise ValueError(("Can only store {0} named data_per_streamline"
+                              " (also known as 'properties' in the TRK"
+                              " format).").format(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE))

         data_for_streamline_keys = sorted(data_for_streamline.keys())
-        self.header['property_name'] = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, dtype='S20')
+        self.header['property_name'] = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE,
+                                                dtype='S20')
         for i, k in enumerate(data_for_streamline_keys):
             nb_values = data_for_streamline[k].shape[0]

             if len(k) > 20:
-                raise ValueError("Property name '{0}' is too long (max 20 char.)".format(k))
+                raise ValueError(("Property name '{0}' is too long (max 20"
+                                  " characters).").format(k))
             elif len(k) > 18 and nb_values > 1:
-                raise ValueError("Property name '{0}' is too long (need to be less than 18 characters when storing more than one value".format(k))
+                raise ValueError(("Property name '{0}' is too long (need to be"
+                                  " less than 18 characters when storing more"
+                                  " than one value).").format(k))

             property_name = k
             if nb_values > 1:
-                # Use the last to bytes of the name to store the nb of values associated to this data_for_streamline.
-                property_name = asbytes(k[:18].ljust(18, '\x00')) + b'\x00' + np.array(nb_values, dtype=np.int8).tostring()
+                # Use the last two bytes of the name to store the nb of values
+                # associated to this data_for_streamline.
+                property_name = asbytes(k[:18].ljust(18, '\x00')) + b'\x00' \
+                                + np.array(nb_values, dtype=np.int8).tostring()

             self.header['property_name'][i] = property_name

-        # Update the 'scalar_name' field using 'data_per_point' of the tractogram.
+        # Update the 'scalar_name' field using 'data_per_point' of the
+        # tractogram.
         data_for_points = first_item.data_for_points
         if len(data_for_points) > MAX_NB_NAMED_SCALARS_PER_POINT:
-            raise ValueError("Can only store {0} named data_per_point (scalars).".format(MAX_NB_NAMED_SCALARS_PER_POINT))
+            raise ValueError(("Can only store {0} named data_per_point (also"
+                              " known as 'scalars' in the TRK format)."
+                              ).format(MAX_NB_NAMED_SCALARS_PER_POINT))

         data_for_points_keys = sorted(data_for_points.keys())
-        self.header['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20')
+        self.header['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT,
+                                              dtype='S20')
         for i, k in enumerate(data_for_points_keys):
             nb_values = data_for_points[k].shape[1]

             if len(k) > 20:
-                raise ValueError("Scalar name '{0}' is too long (max 18 char.)".format(k))
+                raise ValueError(("Scalar name '{0}' is too long (max 20"
+                                  " characters).").format(k))
             elif len(k) > 18 and nb_values > 1:
-                raise ValueError("Scalar name '{0}' is too long (need to be less than 18 characters when storing more than one value".format(k))
+                raise ValueError(("Scalar name '{0}' is too long (need to be"
+                                  " less than 18 characters when storing more"
+                                  " than one value).").format(k))

             scalar_name = k
             if nb_values > 1:
-                # Use the last to bytes of the name to store the nb of values associated to this data_for_streamline.
-                scalar_name = asbytes(k[:18].ljust(18, '\x00')) + b'\x00' + np.array(nb_values, dtype=np.int8).tostring()
+                # Use the last two bytes of the name to store the nb of values
+                # associated to this data_for_points.
+                scalar_name = asbytes(k[:18].ljust(18, '\x00')) + b'\x00' \
+                              + np.array(nb_values, dtype=np.int8).tostring()

             self.header['scalar_name'][i] = scalar_name

@@ -311,15 +348,16 @@ def write(self, tractogram):

         # Apply the inverse of the affine found in the TRK header.
         # rasmm -> voxel
-        affine = np.dot(np.linalg.inv(self.header[Field.VOXEL_TO_RASMM]), affine)
+        affine = np.dot(np.linalg.inv(self.header[Field.VOXEL_TO_RASMM]),
+                        affine)

         # If the voxel order implied by the affine does not match the voxel
         # order in the TRK header, change the orientation.
         # voxel (affine) -> voxel (header)
         header_ornt = asstr(self.header[Field.VOXEL_ORDER])
-        affine_ornt = "".join(nib.orientations.aff2axcodes(self.header[Field.VOXEL_TO_RASMM]))
-        header_ornt = nib.orientations.axcodes2ornt(header_ornt)
-        affine_ornt = nib.orientations.axcodes2ornt(affine_ornt)
+        affine_ornt = "".join(aff2axcodes(self.header[Field.VOXEL_TO_RASMM]))
+        header_ornt = axcodes2ornt(header_ornt)
+        affine_ornt = axcodes2ornt(affine_ornt)
         ornt = nib.orientations.ornt_transform(affine_ornt, header_ornt)
         M = nib.orientations.inv_ornt_aff(ornt, self.header[Field.DIMENSIONS])
         affine = np.dot(M, affine)
@@ -366,10 +404,12 @@ def write(self, tractogram):

         # Check for errors
         if nb_scalars_per_point != int(nb_scalars_per_point):
-            raise DataError("Nb. of scalars differs from one point to another!")
+            msg = "Nb. of scalars differs from one point to another!"
+            raise DataError(msg)

         if nb_properties_per_streamline != int(nb_properties_per_streamline):
-            raise DataError("Nb. of properties differs from one streamline to another!")
+            msg = "Nb. of properties differs from one streamline to another!"
+            raise DataError(msg)

         self.header[Field.NB_STREAMLINES] = self.nb_streamlines
         self.header[Field.NB_SCALARS_PER_POINT] = nb_scalars_per_point
@@ -411,8 +451,9 @@ def __init__(self, tractogram, header=None):

         Notes
         -----
-        Streamlines of the tractogram are assumed to be in *RAS+* and *mm* space
-        where coordinate (0,0,0) refers to the center of the voxel.
+ Streamlines of the tractogram are assumed to be in *RAS+* + and *mm* space where coordinate (0,0,0) refers to the center + of the voxel. """ if header is None: header_rec = TrkWriter.create_empty_header() @@ -589,9 +630,9 @@ def load(cls, fileobj, lazy_load=False): # order in the TRK header, change the orientation. # voxel (header) -> voxel (affine) header_ornt = asstr(trk_reader.header[Field.VOXEL_ORDER]) - affine_ornt = "".join(nib.orientations.aff2axcodes(trk_reader.header[Field.VOXEL_TO_RASMM])) - header_ornt = nib.orientations.axcodes2ornt(header_ornt) - affine_ornt = nib.orientations.axcodes2ornt(affine_ornt) + affine_ornt = "".join(aff2axcodes(trk_reader.header[Field.VOXEL_TO_RASMM])) + header_ornt = axcodes2ornt(header_ornt) + affine_ornt = axcodes2ornt(affine_ornt) ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt) M = nib.orientations.inv_ornt_aff(ornt, trk_reader.header[Field.DIMENSIONS]) affine = np.dot(M, affine) @@ -609,7 +650,8 @@ def load(cls, fileobj, lazy_load=False): if len(scalar_name) == 0: continue - # Check if we encoded the number of values we stocked for this scalar name. + # Check if we encoded the number of values we stocked for this + # scalar name. nb_scalars = 1 if scalar_name[-2] == '\x00' and scalar_name[-1] != '\x00': nb_scalars = int(np.fromstring(scalar_name[-1], np.int8)) @@ -629,7 +671,8 @@ def load(cls, fileobj, lazy_load=False): if len(property_name) == 0: continue - # Check if we encoded the number of values we stocked for this property name. + # Check if we encoded the number of values we stocked for this + # property name. nb_properties = 1 if property_name[-2] == '\x00' and property_name[-1] != '\x00': nb_properties = int(np.fromstring(property_name[-1], np.int8)) @@ -641,7 +684,6 @@ def load(cls, fileobj, lazy_load=False): if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: data_per_streamline_slice['properties'] = slice(cpt, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]) - if lazy_load: def _read(): for pts, scals, props in trk_reader: @@ -668,14 +710,6 @@ def _read(): # Bring tractogram to RAS+ and mm space tractogram.apply_affine(affine.astype(np.float32)) - ## Perform some integrity checks - #if tractogram.header.voxel_sizes != trk_reader.header[Field.VOXEL_SIZES]: - # raise HeaderError("'voxel_sizes' does not match the affine.") - #if tractogram.header.nb_scalars_per_point != trk_reader.header[Field.NB_SCALARS_PER_POINT]: - # raise HeaderError("'nb_scalars_per_point' does not match.") - #if tractogram.header.nb_properties_per_streamline != trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: - # raise HeaderError("'nb_properties_per_streamline' does not match.") - return cls(tractogram, header=trk_reader.header) def save(self, fileobj): @@ -706,7 +740,6 @@ def __str__(self): info : string Header information relevant to the TRK format. 
''' - #trk_reader = TrkReader(fileobj) hdr = self.header info = "" @@ -732,6 +765,5 @@ def __str__(self): info += "\nswap_zx: {0}".format(hdr['swap_zx']) info += "\nn_count: {0}".format(hdr[Field.NB_STREAMLINES]) info += "\nhdr_size: {0}".format(hdr['hdr_size']) - #info += "endianess: {0}".format(hdr[Field.ENDIAN]) return info From 9d1b5a87179970c2824871d4773796d5ec148678 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Thu, 4 Feb 2016 02:02:17 -0500 Subject: [PATCH 069/135] Changed type of CompactList's _offsets and _lengths from list to ndarray --- nibabel/streamlines/compact_list.py | 53 +++++++++++-------- .../streamlines/tests/test_compact_list.py | 35 ++++++------ nibabel/streamlines/trk.py | 9 +++- 3 files changed, 58 insertions(+), 39 deletions(-) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index 93e4d08fbf..edfa153de1 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -23,8 +23,8 @@ def __init__(self, iterable=None): """ # Create new empty `CompactList` object. self._data = np.array(0) - self._offsets = [] - self._lengths = [] + self._offsets = np.array([], dtype=int) + self._lengths = np.array([], dtype=int) if isinstance(iterable, CompactList): # Create a view. @@ -33,6 +33,8 @@ def __init__(self, iterable=None): self._lengths = iterable._lengths elif iterable is not None: + offsets = [] + lengths = [] # Initialize the `CompactList` object from iterable's item. offset = 0 for i, e in enumerate(iterable): @@ -48,11 +50,14 @@ def __init__(self, iterable=None): nb_points += len(e) + CompactList.BUFFER_SIZE self._data.resize((nb_points,) + self.shape) - self._offsets.append(offset) - self._lengths.append(len(e)) + offsets.append(offset) + lengths.append(len(e)) self._data[offset:offset+len(e)] = e offset += len(e) + self._offsets = np.asarray(offsets) + self._lengths = np.asarray(lengths) + # Clear unused memory. 
if self._data.ndim != 0: self._data.resize((offset,) + self.shape) @@ -81,16 +86,16 @@ def append(self, element): """ if self._data.ndim == 0: self._data = np.asarray(element).copy() - self._offsets.append(0) - self._lengths.append(len(element)) + self._offsets = np.array([0]) + self._lengths = np.array([len(element)]) return if element.shape[1:] != self.shape: raise ValueError("All dimensions, except the first one," " must match exactly") - self._offsets.append(len(self._data)) - self._lengths.append(len(element)) + self._offsets = np.r_[self._offsets, len(self._data)] + self._lengths = np.r_[self._lengths, len(element)] self._data = np.append(self._data, element, axis=0) def extend(self, elements): @@ -113,17 +118,20 @@ def extend(self, elements): self._data.resize((self._data.shape[0]+sum(elements._lengths), self._data.shape[1])) + offsets = [] for offset, length in zip(elements._offsets, elements._lengths): - self._offsets.append(next_offset) - self._lengths.append(length) + offsets.append(next_offset) self._data[next_offset:next_offset+length] = elements._data[offset:offset+length] next_offset += length + self._lengths = np.r_[self._lengths, elements._lengths] + self._offsets = np.r_[self._offsets, offsets] + else: self._data = np.concatenate([self._data] + list(elements), axis=0) lengths = list(map(len, elements)) - self._lengths.extend(lengths) - self._offsets.extend(np.cumsum([next_offset] + lengths).tolist()[:-1]) + self._lengths = np.r_[self._lengths, lengths] + self._offsets = np.r_[self._offsets, np.cumsum([next_offset] + lengths)[:-1]] def copy(self): """ Creates a copy of this ``CompactList`` object. """ @@ -135,12 +143,15 @@ def copy(self): clist._data = np.empty((total_lengths,) + self._data.shape[1:], dtype=self._data.dtype) - idx = 0 + next_offset = 0 + offsets = [] for offset, length in zip(self._offsets, self._lengths): - clist._offsets.append(idx) - clist._lengths.append(length) - clist._data[idx:idx+length] = self._data[offset:offset+length] - idx += length + offsets.append(next_offset) + clist._data[next_offset:next_offset+length] = self._data[offset:offset+length] + next_offset += length + + clist._offsets = np.asarray(offsets) + clist._lengths = self._lengths.copy() return clist @@ -172,8 +183,8 @@ def __getitem__(self, idx): elif isinstance(idx, list): clist = CompactList() clist._data = self._data - clist._offsets = [self._offsets[i] for i in idx] - clist._lengths = [self._lengths[i] for i in idx] + clist._offsets = self._offsets[idx] + clist._lengths = self._lengths[idx] return clist elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: @@ -216,6 +227,6 @@ def load_compact_list(filename): content = np.load(filename) clist = CompactList() clist._data = content["data"] - clist._offsets = content["offsets"].tolist() - clist._lengths = content["lengths"].tolist() + clist._offsets = content["offsets"]#.tolist() + clist._lengths = content["lengths"]#.tolist() return clist diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py index 794f11ce95..8f139fad28 100644 --- a/nibabel/streamlines/tests/test_compact_list.py +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -40,8 +40,8 @@ def test_creating_compactlist_from_list(self): assert_equal(len(clist._lengths), len(data)) assert_equal(clist._data.shape[0], sum(lengths)) assert_equal(clist._data.shape[1], 3) - assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) - assert_equal(clist._lengths, lengths) + assert_array_equal(clist._offsets, 
np.r_[0, np.cumsum(lengths)[:-1]]) + assert_array_equal(clist._lengths, lengths) assert_equal(clist.shape, data[0].shape[1:]) # Empty list @@ -61,8 +61,8 @@ def test_creating_compactlist_from_list(self): assert_equal(len(clist._lengths), len(data)) assert_equal(clist._data.shape[0], sum(lengths)) assert_equal(clist._data.shape[1], 3) - assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) - assert_equal(clist._lengths, lengths) + assert_array_equal(clist._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) + assert_array_equal(clist._lengths, lengths) assert_equal(clist.shape, data[0].shape[1:]) CompactList.BUFFER_SIZE = old_buffer_size @@ -78,8 +78,8 @@ def test_creating_compactlist_from_generator(self): assert_equal(len(clist._lengths), len(data)) assert_equal(clist._data.shape[0], sum(lengths)) assert_equal(clist._data.shape[1], 3) - assert_equal(clist._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) - assert_equal(clist._lengths, lengths) + assert_array_equal(clist._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) + assert_array_equal(clist._lengths, lengths) assert_equal(clist.shape, data[0].shape[1:]) # Already consumed generator @@ -102,8 +102,8 @@ def test_creating_compactlist_from_compact_list(self): assert_equal(len(clist2._lengths), len(data)) assert_equal(clist2._data.shape[0], sum(lengths)) assert_equal(clist2._data.shape[1], 3) - assert_equal(clist2._offsets, [0] + np.cumsum(lengths)[:-1].tolist()) - assert_equal(clist2._lengths, lengths) + assert_array_equal(clist2._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) + assert_array_equal(clist2._lengths, lengths) assert_equal(clist2.shape, data[0].shape[1:]) def test_compactlist_iter(self): @@ -133,6 +133,10 @@ def test_compactlist_copy(self): assert_true(clist._data.shape[0] < self.clist._data.shape[0]) assert_true(len(clist) < len(self.clist)) assert_true(clist._data is not self.clist._data) + assert_array_equal(clist._lengths, self.clist[::2]._lengths) + assert_array_equal(clist._offsets, + np.cumsum(np.r_[0, self.clist[::2]._lengths])[:-1]) + assert_arrays_equal(clist, self.clist[::2]) def test_compactlist_append(self): # Maybe not necessary if `self.setUp` is always called before a @@ -178,7 +182,7 @@ def test_compactlist_extend(self): assert_array_equal(clist._offsets[-len(new_data):], len(self.clist._data) + np.cumsum([0] + lengths[:-1])) - assert_equal(clist._lengths[-len(new_data):], lengths) + assert_array_equal(clist._lengths[-len(new_data):], lengths) assert_array_equal(clist._data[-sum(lengths):], np.concatenate(new_data, axis=0)) @@ -188,9 +192,9 @@ def test_compactlist_extend(self): clist.extend(new_clist) assert_equal(len(clist), len(self.clist)+len(new_clist)) assert_array_equal(clist._offsets[-len(new_clist):], - len(self.clist._data) + np.cumsum([0] + lengths[:-1])) + len(self.clist._data) + np.cumsum(np.r_[0, lengths[:-1]])) - assert_equal(clist._lengths[-len(new_clist):], lengths) + assert_array_equal(clist._lengths[-len(new_clist):], lengths) assert_array_equal(clist._data[-sum(lengths):], new_clist._data) # Extend with another `CompactList` object that is a view (e.g. been sliced). 
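
With `_offsets` and `_lengths` now held as ndarrays, slicing a `CompactList`
stays cheap: the view shares the flat `_data` buffer and only fancy-indexes
the two bookkeeping arrays, which is what the assertions above verify. A
small sketch of that view behaviour (illustrative usage, not from the test
suite):

    import numpy as np
    from nibabel.streamlines.compact_list import CompactList

    clist = CompactList([np.ones((n, 3)) for n in (2, 3, 4)])
    view = clist[::2]                  # no copy of the point data
    print(view._data is clist._data)   # True -> shared buffer
    print(view._lengths)               # [2 4]
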
@@ -201,9 +205,9 @@ def test_compactlist_extend(self): assert_equal(len(clist), len(self.clist)+len(new_clist)) assert_equal(len(clist._data), len(self.clist._data)+sum(new_clist._lengths)) assert_array_equal(clist._offsets[-len(new_clist):], - len(self.clist._data) + np.cumsum([0] + new_clist._lengths[:-1])) + len(self.clist._data) + np.cumsum(np.r_[0, new_clist._lengths[:-1]])) - assert_equal(clist._lengths[-len(new_clist):], lengths[::2]) + assert_array_equal(clist._lengths[-len(new_clist):], lengths[::2]) assert_array_equal(clist._data[-sum(new_clist._lengths):], new_clist.copy()._data) assert_arrays_equal(clist[-len(new_clist):], new_clist) @@ -216,7 +220,6 @@ def test_compactlist_extend(self): assert_array_equal(clist._lengths, new_clist._lengths) assert_array_equal(clist._data, new_clist._data) - def test_compactlist_getitem(self): # Get one item for i, e in enumerate(self.clist): @@ -248,9 +251,9 @@ def test_compactlist_getitem(self): assert_true(clist_view is not self.clist) assert_true(clist_view._data is self.clist._data) assert_array_equal(clist_view._offsets, - np.asarray(self.clist._offsets)[idx]) + self.clist._offsets[idx]) assert_array_equal(clist_view._lengths, - np.asarray(self.clist._lengths)[idx]) + self.clist._lengths[idx]) assert_array_equal(clist_view[0], self.clist[1]) assert_array_equal(clist_view[1], self.clist[2]) assert_array_equal(clist_view[2], self.clist[4]) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index b96ce08c6c..424e1587de 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -528,6 +528,8 @@ def _create_compactlist_from_generator(cls, gen): properties = np.empty((cls.READ_BUFFER_SIZE, props.shape[0]), dtype=props.dtype) offset = 0 + offsets = [] + lengths = [] for i, (pts, scals, props) in enumerate(gen): pts = np.asarray(pts) scals = np.asarray(scals) @@ -547,8 +549,8 @@ def _create_compactlist_from_generator(cls, gen): streamlines._data.resize((len(streamlines._data) + len(pts)+cls.READ_BUFFER_SIZE, pts.shape[1])) scalars._data.resize((len(scalars._data) + len(scals)+cls.READ_BUFFER_SIZE, scals.shape[1])) - streamlines._offsets.append(offset) - streamlines._lengths.append(len(pts)) + offsets.append(offset) + lengths.append(len(pts)) streamlines._data[offset:offset+len(pts)] = pts scalars._data[offset:offset+len(scals)] = scals @@ -559,6 +561,9 @@ def _create_compactlist_from_generator(cls, gen): properties[i] = props + streamlines._offsets = np.asarray(offsets) + streamlines._lengths = np.asarray(lengths) + # Clear unused memory. 
streamlines._data.resize((offset, pts.shape[1])) From 2e2abe7cf67aff2327a56cc0894192ab37a33fa8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Thu, 4 Feb 2016 19:18:52 -0500 Subject: [PATCH 070/135] [ENH] CompactList now support advanced indexing with numpy array of integers --- nibabel/streamlines/compact_list.py | 4 ++-- nibabel/streamlines/tests/test_compact_list.py | 18 +++++++++++++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index edfa153de1..8dbeb5d976 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -173,14 +173,14 @@ def __getitem__(self, idx): start = self._offsets[idx] return self._data[start:start+self._lengths[idx]] - elif isinstance(idx, slice): + elif isinstance(idx, slice) or isinstance(idx, list): clist = CompactList() clist._data = self._data clist._offsets = self._offsets[idx] clist._lengths = self._lengths[idx] return clist - elif isinstance(idx, list): + elif isinstance(idx, np.ndarray) and np.issubdtype(idx.dtype, np.integer): clist = CompactList() clist._data = self._data clist._offsets = self._offsets[idx] diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py index 8f139fad28..67fe92475a 100644 --- a/nibabel/streamlines/tests/test_compact_list.py +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -226,15 +226,27 @@ def test_compactlist_getitem(self): assert_array_equal(self.clist[i], e) # Get multiple items (this will create a view). - clist_view = self.clist[list(range(len(self.clist)))] + indices = list(range(len(self.clist))) + clist_view = self.clist[indices] assert_true(clist_view is not self.clist) assert_true(clist_view._data is self.clist._data) assert_true(clist_view._offsets is not self.clist._offsets) assert_true(clist_view._lengths is not self.clist._lengths) assert_array_equal(clist_view._offsets, self.clist._offsets) assert_array_equal(clist_view._lengths, self.clist._lengths) - for e1, e2 in zip_longest(clist_view, self.clist): - assert_array_equal(e1, e2) + assert_arrays_equal(clist_view, self.clist) + + # Get multiple items using ndarray of data type. + for dtype in [np.int8, np.int16, np.int32, np.int64]: + clist_view = self.clist[np.array(indices, dtype=dtype)] + assert_true(clist_view is not self.clist) + assert_true(clist_view._data is self.clist._data) + assert_true(clist_view._offsets is not self.clist._offsets) + assert_true(clist_view._lengths is not self.clist._lengths) + assert_array_equal(clist_view._offsets, self.clist._offsets) + assert_array_equal(clist_view._lengths, self.clist._lengths) + for e1, e2 in zip_longest(clist_view, self.clist): + assert_array_equal(e1, e2) # Get slice (this will create a view). clist_view = self.clist[::2] From d459e805f1bcd8b40ccb474c7e56fc33c57ab7b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 9 Feb 2016 09:37:10 -0500 Subject: [PATCH 071/135] In the doc, use :class: for a nicer markup with a link. 
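For readers unfamiliar with the markup: `:class:` is a Sphinx cross-reference role, and the rendered docs link to the class definition when its target is wrapped in backquotes. A minimal docstring sketch of the convention this series converges on (illustrative only, not part of the diff):

    def load(fileobj):
        """ Loads streamlines from a file-like object.

        Returns
        -------
        tractogram_file : :class:`TractogramFile` object
            Instance containing the data and metadata read from `fileobj`.
        """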
--- nibabel/streamlines/__init__.py | 16 +++----- nibabel/streamlines/compact_list.py | 18 ++++----- nibabel/streamlines/tractogram.py | 55 +++++++++++++------------- nibabel/streamlines/tractogram_file.py | 2 +- nibabel/streamlines/trk.py | 8 ++-- nibabel/streamlines/utils.py | 2 +- 6 files changed, 49 insertions(+), 52 deletions(-) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 42c0619531..6b2ab2a0a7 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -45,9 +45,8 @@ def detect_format(fileobj): Returns ------- - tractogram_file : ``TractogramFile`` class - Returns an instance of a `TractogramFile` class containing data and - metadata of the tractogram contained from `fileobj`. + tractogram_file : :class:TractogramFile class + The class type guessed from the content of `fileobj`. ''' for format in FORMATS.values(): try: @@ -78,13 +77,10 @@ def load(fileobj, lazy_load=False, ref=None): Load streamlines in a lazy manner i.e. they will not be kept in memory. - ref : filename | `Nifti1Image` object | 2D array (4,4) (optional) - Reference space where streamlines will live in `fileobj`. - Returns ------- - tractogram_file : ``TractogramFile`` - Returns an instance of a `TractogramFile` class containing data and + tractogram_file : :class:TractogramFile object + Returns an instance of a :class:TractogramFile containing data and metadata of the tractogram loaded from `fileobj`. ''' tractogram_file = detect_format(fileobj) @@ -100,7 +96,7 @@ def save(tractogram_file, filename): Parameters ---------- - tractogram_file : ``TractogramFile`` object + tractogram_file : :class:TractogramFile object Tractogram to be saved on disk. filename : str @@ -115,7 +111,7 @@ def save_tractogram(tractogram, filename, **kwargs): Parameters ---------- - tractogram : ``Tractogram`` object + tractogram : :class:Tractogram object Tractogram to be saved. filename : str diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index 8dbeb5d976..3ea8d04cb6 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -13,12 +13,12 @@ def __init__(self, iterable=None): Parameters ---------- iterable : iterable (optional) - If specified, create a ``CompactList`` object initialized from - iterable's items. Otherwise, create an empty ``CompactList``. + If specified, create a :class:CompactList object initialized from + iterable's items, otherwise it will be empty. Notes ----- - If `iterable` is a ``CompactList`` object, a view is returned and no + If `iterable` is a :class:CompactList object, a view is returned and no memory is allocated. For an actual copy use the `.copy()` method. """ # Create new empty `CompactList` object. @@ -103,7 +103,7 @@ def extend(self, elements): Parameters ---------- - elements : list of ndarrays, ``CompactList`` object + elements : list of ndarrays, :class:CompactList object Elements to append. The shape must match already inserted elements shape except for the first dimension. @@ -134,7 +134,7 @@ def extend(self, elements): self._offsets = np.r_[self._offsets, np.cumsum([next_offset] + lengths)[:-1]] def copy(self): - """ Creates a copy of this ``CompactList`` object. """ + """ Creates a copy of this :class:CompactList object. """ # We do not simply deepcopy this object since we might have a chance # to use less memory. For example, if the compact list being copied # is the result of a slicing operation on a compact list. 
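The comment above is worth a concrete sketch: slicing produces a view that shares `_data`, and `copy()` re-packs only the rows the view references (a minimal sketch against the in-progress CompactList API; not part of the patch):

    import numpy as np
    from nibabel.streamlines.compact_list import CompactList

    clist = CompactList([np.random.rand(n, 3) for n in (10, 20, 30)])
    view = clist[::2]        # a view: shares clist._data, new offsets/lengths
    compacted = view.copy()  # re-packs only the rows the view actually uses
    assert compacted._data.shape[0] == sum(view._lengths)  # 10 + 30 rows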
@@ -215,7 +215,7 @@ def __repr__(self): def save_compact_list(filename, clist): - """ Saves a `CompactList` object to a .npz file. """ + """ Saves a :class:CompactList object to a .npz file. """ np.savez(filename, data=clist._data, offsets=clist._offsets, @@ -223,10 +223,10 @@ def save_compact_list(filename, clist): def load_compact_list(filename): - """ Loads a `CompactList` object from a .npz file. """ + """ Loads a :class:CompactList object from a .npz file. """ content = np.load(filename) clist = CompactList() clist._data = content["data"] - clist._offsets = content["offsets"]#.tolist() - clist._lengths = content["lengths"]#.tolist() + clist._offsets = content["offsets"] + clist._lengths = content["lengths"] return clist diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 8568fa4a21..97c0a7c537 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -14,7 +14,7 @@ class UsageWarning(Warning): class TractogramItem(object): """ Class containing information about one streamline. - ``TractogramItem`` objects have three main properties: `streamline`, + :class:TractogramItem objects have three main properties: `streamline`, `data_for_streamline`, and `data_for_points`. Parameters @@ -42,11 +42,12 @@ def __len__(self): class Tractogram(object): """ Class containing information about streamlines. - Tractogram objects have three main properties: ``streamlines`` + Tractogram objects have three main properties: `streamlines`, + `data_per_streamline` and `data_per_point`. Attributes ---------- - affine_to_rasmm : 2D array (4,4) + affine_to_rasmm : 2D ndarray (4,4) Affine that brings the streamlines back to *RAS+* and *mm* space where coordinate (0,0,0) refers to the center of the voxel. @@ -102,7 +103,7 @@ def __setitem__(self, key, value): self.store[key] = value class DataPerPointDict(DataDict): - """ Internal dictionary that makes sure data are `CompactList`. """ + """ Internal dictionary making sure data are :class:CompactList objects. """ def __setitem__(self, key, value): value = CompactList(value) @@ -130,12 +131,12 @@ def __init__(self, streamlines=None, data_per_streamline : dict of list of ndarray of shape (P,) (optional) Sequence of T ndarrays of shape (P,) where T is the number of - streamlines defined by ``streamlines``, P is the number of + streamlines defined by `streamlines`, P is the number of properties associated to each streamline. data_per_point : dict of list of ndarray of shape (Nt, M) (optional) Sequence of T ndarrays of shape (Nt, M) where T is the number - of streamlines defined by ``streamlines``, Nt is the number of + of streamlines defined by `streamlines`, Nt is the number of points for a particular streamline t and M is the number of scalars associated to each point (excluding the three coordinates). @@ -200,7 +201,7 @@ def __len__(self): return len(self.streamlines) def copy(self): - """ Returns a copy of this `Tractogram` object. """ + """ Returns a copy of this :class:Tractogram object. """ data_per_streamline = {} for key in self.data_per_streamline: data_per_streamline[key] = self.data_per_streamline[key].copy() @@ -228,11 +229,11 @@ def apply_affine(self, affine, lazy=False): Returns ------- - tractogram : ``Tractogram`` or ``LazyTractogram`` object + tractogram : :class:Tractogram or :class:LazyTractogram object Tractogram where the streamlines have been transformed according to the given affine transformation. 
If the `lazy` option is true, - it returns a ``LazyTractogram`` object, otherwise it returns a - reference to this ``Tractogram`` object with updated streamlines. + it returns a :class:LazyTractogram object, otherwise it returns a + reference to this :class:Tractogram object with updated streamlines. """ if lazy: @@ -258,15 +259,15 @@ def apply_affine(self, affine, lazy=False): class LazyTractogram(Tractogram): ''' Class containing information about streamlines. - Tractogram objects have four main properties: ``header``, ``streamlines``, - ``scalars`` and ``properties``. Tractogram objects are iterable and - produce tuple of ``streamlines``, ``scalars`` and ``properties`` for each + Tractogram objects have four main properties: `header`, `streamlines`, + `scalars` and `properties`. Tractogram objects are iterable and + produce tuple of `streamlines`, `scalars` and `properties` for each streamline. Notes ----- - If provided, ``scalars`` and ``properties`` must yield the same number of - values as ``streamlines``. + If provided, `scalars` and `properties` must yield the same number of + values as `streamlines`. ''' class LazyDict(collections.MutableMapping): @@ -331,16 +332,16 @@ def __init__(self, streamlines=None, @classmethod def from_tractogram(cls, tractogram): - ''' Creates a ``LazyTractogram`` object from a ``Tractogram`` object. + ''' Creates a :class:LazyTractogram object from a :class:Tractogram object. Parameters ---------- - tractogram : ``Tractgogram`` object - Tractogram from which to create a ``LazyTractogram`` object. + tractogram : :class:Tractgogram object + Tractogram from which to create a :class:LazyTractogram object. Returns ------- - lazy_tractogram : ``LazyTractogram`` object + lazy_tractogram : :class:LazyTractogram object New lazy tractogram. ''' @@ -362,19 +363,19 @@ def from_tractogram(cls, tractogram): @classmethod def create_from(cls, data_func): - ''' Creates a ``LazyTractogram`` from a coroutine yielding - ``TractogramItem`` objects. + ''' Creates a :class:LazyTractogram from a coroutine yielding + :class:TractogramItem objects. Parameters ---------- - data_func : coroutine yielding ``TractogramItem`` objects + data_func : coroutine yielding :class:TractogramItem objects A function that whenever it is called starts yielding - ``TractogramItem`` objects that should be part of this + :class:TractogramItem objects that should be part of this LazyTractogram. Returns ------- - lazy_tractogram : ``LazyTractogram`` object + lazy_tractogram : :class:LazyTractogram object New lazy tractogram. ''' @@ -508,7 +509,7 @@ def __len__(self): return self._nb_streamlines def copy(self): - """ Returns a copy of this `LazyTractogram` object. """ + """ Returns a copy of this :class:LazyTractogram object. """ tractogram = LazyTractogram(self._streamlines, self._data_per_streamline, self._data_per_point) @@ -530,8 +531,8 @@ def apply_affine(self, affine): Returns ------- - lazy_tractogram : ``LazyTractogram`` object - Reference to this instance of ``LazyTractogram``. + lazy_tractogram : :class:LazyTractogram object + Reference to this instance of :class:LazyTractogram. """ # Update the affine that will be applied when returning streamlines. 
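A short usage sketch of the `lazy` flag documented above (in-progress API; the scaling affine is an arbitrary example):

    import numpy as np
    from nibabel.streamlines.tractogram import Tractogram

    tractogram = Tractogram([np.random.rand(10, 3), np.random.rand(15, 3)])
    scaling = np.diag([2.0, 2.0, 2.0, 1.0])  # arbitrary example affine
    tractogram.apply_affine(scaling)         # eager: streamlines updated in-place
    lazy = tractogram.apply_affine(scaling, lazy=True)  # LazyTractogram; applied on iteration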
diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index 9eeb2a303c..e0a64e9d02 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -109,7 +109,7 @@ def load(cls, fileobj, lazy_load=True): Returns ------- - tractogram_file : ``TractogramFile`` object + tractogram_file : :class:TractogramFile object Returns an object containing tractogram data and header information. ''' diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 424e1587de..a22a698c16 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -443,10 +443,10 @@ def __init__(self, tractogram, header=None): """ Parameters ---------- - tractogram : ``Tractogram`` object - Tractogram that will be contained in this ``TrkFile``. + tractogram : :class:Tractogram object + Tractogram that will be contained in this :class:TrkFile. - header : ``TractogramHeader`` file (optional) + header : dict (optional) Metadata associated to this tractogram file. Notes @@ -602,7 +602,7 @@ def load(cls, fileobj, lazy_load=False): Returns ------- - trk_file : ``TrkFile`` object + trk_file : :class:TrkFile object Returns an object containing tractogram data and header information. diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index 7bbbe1ef8d..0fa1e4f47e 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -10,7 +10,7 @@ def get_affine_from_reference(ref): Parameter --------- - ref : filename | `Nifti1Image` object | 2D array (4,4) + ref : filename | :class:Nifti1Image object | 2D array (4,4) Reference space where streamlines live in `fileobj`. Returns From a641deb552d2424b8aefd385702295539a8cf6b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 9 Feb 2016 11:00:56 -0500 Subject: [PATCH 072/135] Fused save and save_tractogram functions. --- nibabel/streamlines/__init__.py | 76 ++++++++++--------- nibabel/streamlines/tests/test_streamlines.py | 16 ++-- 2 files changed, 46 insertions(+), 46 deletions(-) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 6b2ab2a0a7..2550373865 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -4,6 +4,7 @@ from .header import Field from .compact_list import CompactList from .tractogram import Tractogram, LazyTractogram +from .tractogram_file import TractogramFile from .trk import TrkFile #from .tck import TckFile @@ -17,7 +18,7 @@ def is_supported(fileobj): - ''' Checks if the file-like object if supported by NiBabel. + """ Checks if the file-like object if supported by NiBabel. Parameters ---------- @@ -29,12 +30,13 @@ def is_supported(fileobj): Returns ------- is_supported : boolean - ''' + + """ return detect_format(fileobj) is not None def detect_format(fileobj): - ''' Returns the StreamlinesFile object guessed from the file-like object. + """ Returns the StreamlinesFile object guessed from the file-like object. Parameters ---------- @@ -45,9 +47,10 @@ def detect_format(fileobj): Returns ------- - tractogram_file : :class:TractogramFile class + tractogram_file : :class:`TractogramFile` class The class type guessed from the content of `fileobj`. - ''' + + """ for format in FORMATS.values(): try: if format.is_correct_format(fileobj): @@ -64,7 +67,7 @@ def detect_format(fileobj): def load(fileobj, lazy_load=False, ref=None): - ''' Loads streamlines from a file-like object in voxel space. 
+ """ Loads streamlines in *RAS+* and *mm* space from a file-like object. Parameters ---------- @@ -72,17 +75,21 @@ def load(fileobj, lazy_load=False, ref=None): If string, a filename; otherwise an open file-like object pointing to a streamlines file (and ready to read from the beginning of the streamlines file's header). - - lazy_load : boolean (optional) - Load streamlines in a lazy manner i.e. they will not be kept - in memory. + lazy_load : {False, True}, optional + If True, load streamlines in a lazy manner i.e. they will not be kept + in memory. Otherwise, load all streamlines in memory. Returns ------- tractogram_file : :class:TractogramFile object Returns an instance of a :class:TractogramFile containing data and metadata of the tractogram loaded from `fileobj`. - ''' + + Notes + ----- + The streamline coordinate (0,0,0) refers to the center of the voxel. + + """ tractogram_file = detect_format(fileobj) if tractogram_file is None: @@ -91,37 +98,32 @@ def load(fileobj, lazy_load=False, ref=None): return tractogram_file.load(fileobj, lazy_load=lazy_load) -def save(tractogram_file, filename): - ''' Saves a tractogram to a file. +def save(tractogram, filename, **kwargs): + """ Saves a tractogram to a file. Parameters ---------- - tractogram_file : :class:TractogramFile object - Tractogram to be saved on disk. - + tractogram : :class:`Tractogram` object or :class:`TractogramFile` object + If :class:`Tractogram` object, the file format will be guessed from + `filename` and a :class:`TractogramFile` object will be created using + provided keyword arguments. + If :class:`TractogramFile` object, the file format is known and will + be used to save its content to `filename`. filename : str - Name of the file where the tractogram will be saved. The format will - be guessed from `filename`. - ''' - tractogram_file.save(filename) + Name of the file where the tractogram will be saved. + \*\*kwargs : keyword arguments + Keyword arguments passed to :class:`TractogramFile` constructor. + """ + tractogram_file = tractogram + if isinstance(tractogram, Tractogram): + # We have to guess the file format. + tractogram_file_class = detect_format(filename) -def save_tractogram(tractogram, filename, **kwargs): - ''' Saves a tractogram to a file. + if tractogram_file_class is None: + msg = "Unknown tractogram file format: '{}'".format(filename) + raise ValueError(msg) - Parameters - ---------- - tractogram : :class:Tractogram object - Tractogram to be saved. + tractogram_file = tractogram_file_class(tractogram, **kwargs) - filename : str - Name of the file where the tractogram will be saved. The format will - be guessed from `filename`. 
- ''' - tractogram_file_class = detect_format(filename) - - if tractogram_file_class is None: - raise ValueError("Unknown tractogram file format: '{}'".format(filename)) - - tractogram_file = tractogram_file_class(tractogram, **kwargs) - save(tractogram_file, filename) + tractogram_file.save(filename) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 10eb415748..84b7b5c4b3 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -208,7 +208,7 @@ def test_save_empty_file(self): for ext, cls in nib.streamlines.FORMATS.items(): with InTemporaryDirectory(): with open('streamlines' + ext, 'w+b') as f: - nib.streamlines.save_tractogram(tractogram, f.name) + nib.streamlines.save(tractogram, f.name) tfile = nib.streamlines.load(f, lazy_load=False) assert_tractogram_equal(tfile.tractogram, tractogram) @@ -217,7 +217,7 @@ def test_save_simple_file(self): for ext, cls in nib.streamlines.FORMATS.items(): with InTemporaryDirectory(): with open('streamlines' + ext, 'w+b') as f: - nib.streamlines.save_tractogram(tractogram, f.name) + nib.streamlines.save(tractogram, f.name) tfile = nib.streamlines.load(f, lazy_load=False) assert_tractogram_equal(tfile.tractogram, tractogram) @@ -231,12 +231,11 @@ def test_save_complex_file(self): with open('streamlines' + ext, 'w+b') as f: with clear_and_catch_warnings(record=True, modules=[trk]) as w: - nib.streamlines.save_tractogram(complex_tractogram, - f.name) + nib.streamlines.save(complex_tractogram, f.name) - # If streamlines format does not support saving data per - # point or data per streamline, a warning message should - # be issued. + # If streamlines format does not support saving data + # per point or data per streamline, a warning message + # should be issued. if not (cls.support_data_per_point() and cls.support_data_per_streamline()): assert_equal(len(w), 1) @@ -257,5 +256,4 @@ def test_load_unknown_format(self): assert_raises(ValueError, nib.streamlines.load, "") def test_save_unknown_format(self): - assert_raises(ValueError, - nib.streamlines.save_tractogram, Tractogram(), "") + assert_raises(ValueError, nib.streamlines.save, Tractogram(), "") From 015ba0fcb82700934e622eac49087901a46c6b65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 9 Feb 2016 11:01:08 -0500 Subject: [PATCH 073/135] Fixed DOC --- nibabel/streamlines/compact_list.py | 23 +++++--- nibabel/streamlines/tractogram.py | 75 ++++++++++++++------------ nibabel/streamlines/tractogram_file.py | 37 +++++++------ nibabel/streamlines/trk.py | 63 ++++++++++++---------- nibabel/streamlines/utils.py | 16 ++++-- 5 files changed, 124 insertions(+), 90 deletions(-) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index 3ea8d04cb6..fb2a9de14b 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -13,12 +13,12 @@ def __init__(self, iterable=None): Parameters ---------- iterable : iterable (optional) - If specified, create a :class:CompactList object initialized from + If specified, create a :class:`CompactList` object initialized from iterable's items, otherwise it will be empty. Notes ----- - If `iterable` is a :class:CompactList object, a view is returned and no + If `iterable` is a :class:`CompactList` object, a view is returned and no memory is allocated. For an actual copy use the `.copy()` method. """ # Create new empty `CompactList` object. 
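The flat-storage scheme behind `CompactList` can be made explicit with a small sketch: element `i` lives at `_data[_offsets[i]:_offsets[i] + _lengths[i]]` (the underscore attributes are internals of the in-progress API; not part of the patch):

    import numpy as np
    from nibabel.streamlines.compact_list import CompactList

    clist = CompactList([np.ones((2, 3)), np.zeros((5, 3))])
    i = 1
    rows = slice(clist._offsets[i], clist._offsets[i] + clist._lengths[i])
    assert np.array_equal(clist._data[rows], clist[i])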
@@ -103,9 +103,16 @@ def extend(self, elements): Parameters ---------- - elements : list of ndarrays, :class:CompactList object - Elements to append. The shape must match already inserted elements - shape except for the first dimension. + elements : list of ndarrays or :class:`CompactList` object + If list of ndarrays, each ndarray will be concatenated along the + first dimension then appended to the data of this CompactList. + If :class:`CompactList` object, its data are simply appended to + the data of this CompactList. + + Notes + ----- + The shape of the elements to be added must match the one of the data + of this CompactList except for the first dimension. """ if self._data.ndim == 0: @@ -134,7 +141,7 @@ def extend(self, elements): self._offsets = np.r_[self._offsets, np.cumsum([next_offset] + lengths)[:-1]] def copy(self): - """ Creates a copy of this :class:CompactList object. """ + """ Creates a copy of this :class:`CompactList` object. """ # We do not simply deepcopy this object since we might have a chance # to use less memory. For example, if the compact list being copied # is the result of a slicing operation on a compact list. @@ -215,7 +222,7 @@ def __repr__(self): def save_compact_list(filename, clist): - """ Saves a :class:CompactList object to a .npz file. """ + """ Saves a :class:`CompactList` object to a .npz file. """ np.savez(filename, data=clist._data, offsets=clist._offsets, @@ -223,7 +230,7 @@ def save_compact_list(filename, clist): def load_compact_list(filename): - """ Loads a :class:CompactList object from a .npz file. """ + """ Loads a :class:`CompactList` object from a .npz file. """ content = np.load(filename) clist = CompactList() clist._data = content["data"] diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 97c0a7c537..ea0fa4e4a4 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -14,18 +14,24 @@ class UsageWarning(Warning): class TractogramItem(object): """ Class containing information about one streamline. - :class:TractogramItem objects have three main properties: `streamline`, + :class:`TractogramItem` objects have three main properties: `streamline`, `data_for_streamline`, and `data_for_points`. Parameters ---------- - streamline : ndarray of shape (N, 3) + streamline : ndarray shape (N, 3) Points of this streamline represented as an ndarray of shape (N, 3) where N is the number of points. - data_for_streamline : dict - + Dictionary containing some data associated to this particular + streamline. Each key `k` is mapped to a ndarray of shape (Pk,), where + `Pt` is the dimension of the data associated with key `k`. data_for_points : dict + Dictionary containing some data associated to each point of this + particular streamline. Each key `k` is mapped to a ndarray of + shape (Nt, Mk), where `Nt` is the number of points of this streamline + and `Mk` is the dimension of the data associated with key `k`. + """ def __init__(self, streamline, data_for_streamline, data_for_points): self.streamline = np.asarray(streamline) @@ -47,7 +53,7 @@ class Tractogram(object): Attributes ---------- - affine_to_rasmm : 2D ndarray (4,4) + affine_to_rasmm : ndarray shape (4, 4) Affine that brings the streamlines back to *RAS+* and *mm* space where coordinate (0,0,0) refers to the center of the voxel. @@ -81,7 +87,7 @@ def __len__(self): return len(self.store) class DataPerStreamlineDict(DataDict): - """ Internal dictionary that makes sure data are 2D ndarray. 
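A sketch of the two input types `extend()` accepts, per the docstring above (in-progress API; not part of the patch):

    import numpy as np
    from nibabel.streamlines.compact_list import CompactList

    clist = CompactList([np.random.rand(4, 3)])
    clist.extend([np.random.rand(6, 3), np.random.rand(8, 3)])  # list of ndarrays
    clist.extend(CompactList([np.random.rand(5, 3)]))           # another CompactList
    assert len(clist) == 4 and len(clist._data) == 4 + 6 + 8 + 5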
""" + """ Internal dictionary that makes sure data are 2D array. """ def __setitem__(self, key, value): value = np.asarray(value) @@ -91,7 +97,7 @@ def __setitem__(self, key, value): value.shape = ((len(value), 1)) if value.ndim != 2: - raise ValueError("data_per_streamline must be a 2D ndarray.") + raise ValueError("data_per_streamline must be a 2D array.") # We make sure there is the right amount of values # (i.e. same as the number of streamlines in the tractogram). @@ -103,7 +109,8 @@ def __setitem__(self, key, value): self.store[key] = value class DataPerPointDict(DataDict): - """ Internal dictionary making sure data are :class:CompactList objects. """ + """ Internal dictionary making sure data are :class:`CompactList` objects. + """ def __setitem__(self, key, value): value = CompactList(value) @@ -128,12 +135,10 @@ def __init__(self, streamlines=None, streamlines : list of ndarray of shape (Nt, 3) (optional) Sequence of T streamlines. One streamline is an ndarray of shape (Nt, 3) where Nt is the number of points of streamline t. - data_per_streamline : dict of list of ndarray of shape (P,) (optional) Sequence of T ndarrays of shape (P,) where T is the number of streamlines defined by `streamlines`, P is the number of properties associated to each streamline. - data_per_point : dict of list of ndarray of shape (Nt, M) (optional) Sequence of T ndarrays of shape (Nt, M) where T is the number of streamlines defined by `streamlines`, Nt is the number of @@ -201,7 +206,7 @@ def __len__(self): return len(self.streamlines) def copy(self): - """ Returns a copy of this :class:Tractogram object. """ + """ Returns a copy of this :class:`Tractogram` object. """ data_per_streamline = {} for key in self.data_per_streamline: data_per_streamline[key] = self.data_per_streamline[key].copy() @@ -224,16 +229,21 @@ def apply_affine(self, affine, lazy=False): Parameters ---------- - affine : 2D array (4,4) + affine : ndarray shape (4, 4) Transformation that will be applied to every streamline. + lazy_load : {False, True}, optional + If True, streamlines are *not* transformed in-place and a + :class:`LazyTractogram` object is returned. Otherwise, streamlines + are modified in-place. Returns ------- - tractogram : :class:Tractogram or :class:LazyTractogram object + tractogram : :class:`Tractogram` or :class:`LazyTractogram` object Tractogram where the streamlines have been transformed according to the given affine transformation. If the `lazy` option is true, - it returns a :class:LazyTractogram object, otherwise it returns a - reference to this :class:Tractogram object with updated streamlines. + it returns a :class:`LazyTractogram` object, otherwise it returns a + reference to this :class:`Tractogram` object with updated + streamlines. """ if lazy: @@ -257,7 +267,7 @@ def apply_affine(self, affine, lazy=False): class LazyTractogram(Tractogram): - ''' Class containing information about streamlines. + """ Class containing information about streamlines. Tractogram objects have four main properties: `header`, `streamlines`, `scalars` and `properties`. Tractogram objects are iterable and @@ -268,7 +278,8 @@ class LazyTractogram(Tractogram): ----- If provided, `scalars` and `properties` must yield the same number of values as `streamlines`. - ''' + + """ class LazyDict(collections.MutableMapping): """ Internal dictionary with lazy evaluations. """ @@ -309,12 +320,10 @@ def __init__(self, streamlines=None, streamlines : coroutine yielding ndarrays of shape (Nt,3) (optional) Function yielding streamlines. 
One streamline is an ndarray of shape (Nt,3) where Nt is the number of points of streamline t. - data_per_streamline : dict of coroutines yielding ndarrays of shape (P,) (optional) Function yielding properties for a particular streamline t. The properties are represented as an ndarray of shape (P,) where P is the number of properties associated to each streamline. - data_per_point : dict of coroutines yielding ndarrays of shape (Nt,M) (optional) Function yielding scalars for a particular streamline t. The scalars are represented as an ndarray of shape (Nt,M) where Nt @@ -332,19 +341,19 @@ def __init__(self, streamlines=None, @classmethod def from_tractogram(cls, tractogram): - ''' Creates a :class:LazyTractogram object from a :class:Tractogram object. + """ Creates a :class:`LazyTractogram` object from a :class:`Tractogram` object. Parameters ---------- - tractogram : :class:Tractgogram object - Tractogram from which to create a :class:LazyTractogram object. + tractogram : :class:`Tractgogram` object + Tractogram from which to create a :class:`LazyTractogram` object. Returns ------- - lazy_tractogram : :class:LazyTractogram object + lazy_tractogram : :class:`LazyTractogram` object New lazy tractogram. - ''' + """ data_per_streamline = {} for key, value in tractogram.data_per_streamline.items(): data_per_streamline[key] = lambda: value @@ -363,22 +372,22 @@ def from_tractogram(cls, tractogram): @classmethod def create_from(cls, data_func): - ''' Creates a :class:LazyTractogram from a coroutine yielding - :class:TractogramItem objects. + """ Creates a :class:`LazyTractogram` from a coroutine yielding + :class:`TractogramItem` objects. Parameters ---------- - data_func : coroutine yielding :class:TractogramItem objects + data_func : coroutine yielding :class:`TractogramItem` objects A function that whenever it is called starts yielding - :class:TractogramItem objects that should be part of this + :class:`TractogramItem` objects that should be part of this LazyTractogram. Returns ------- - lazy_tractogram : :class:LazyTractogram object + lazy_tractogram : :class:`LazyTractogram` object New lazy tractogram. - ''' + """ if not callable(data_func): raise TypeError("`data_func` must be a coroutine.") @@ -509,7 +518,7 @@ def __len__(self): return self._nb_streamlines def copy(self): - """ Returns a copy of this :class:LazyTractogram object. """ + """ Returns a copy of this :class:`LazyTractogram` object. """ tractogram = LazyTractogram(self._streamlines, self._data_per_streamline, self._data_per_point) @@ -531,8 +540,8 @@ def apply_affine(self, affine): Returns ------- - lazy_tractogram : :class:LazyTractogram object - Reference to this instance of :class:LazyTractogram. + lazy_tractogram : :class:`LazyTractogram` object + Reference to this instance of :class:`LazyTractogram`. """ # Update the affine that will be applied when returning streamlines. diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index e0a64e9d02..ab0b05467f 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -25,7 +25,7 @@ def __init__(self, callable): class TractogramFile(with_metaclass(ABCMeta)): - ''' Convenience class to encapsulate tractogram file format. ''' + """ Convenience class to encapsulate tractogram file format. """ def __init__(self, tractogram, header=None): self._tractogram = tractogram @@ -62,22 +62,23 @@ def get_affine(self): @abstractclassmethod def get_magic_number(cls): - ''' Returns streamlines file's magic number. 
''' + """ Returns streamlines file's magic number. """ raise NotImplementedError() @abstractclassmethod def support_data_per_point(cls): - ''' Tells if this tractogram format supports saving data per point. ''' + """ Tells if this tractogram format supports saving data per point. """ raise NotImplementedError() @abstractclassmethod def support_data_per_streamline(cls): - ''' Tells if this tractogram format supports saving data per streamline. ''' + """ Tells if this tractogram format supports saving data per streamline. + """ raise NotImplementedError() @abstractclassmethod def is_correct_format(cls, fileobj): - ''' Checks if the file has the right streamlines file format. + """ Checks if the file has the right streamlines file format. Parameters ---------- @@ -88,14 +89,16 @@ def is_correct_format(cls, fileobj): Returns ------- - is_correct_format : boolean - Returns True if `fileobj` is in the right streamlines file format. - ''' + is_correct_format : {True, False} + Returns True if `fileobj` is in the right streamlines file format, + otherwise returns False. + + """ raise NotImplementedError() @abstractclassmethod def load(cls, fileobj, lazy_load=True): - ''' Loads streamlines from a file-like object. + """ Loads streamlines from a file-like object. Parameters ---------- @@ -103,26 +106,28 @@ def load(cls, fileobj, lazy_load=True): If string, a filename; otherwise an open file-like object pointing to a streamlines file (and ready to read from the beginning of the header). - lazy_load : boolean (optional) - Load streamlines in a lazy manner i.e. they will not be kept - in memory. For postprocessing speed, turn off this option. + lazy_load : {False, True}, optional + If True, load streamlines in a lazy manner i.e. they will not be + kept in memory. Otherwise, load all streamlines in memory. Returns ------- - tractogram_file : :class:TractogramFile object + tractogram_file : :class:`TractogramFile` object Returns an object containing tractogram data and header information. - ''' + + """ raise NotImplementedError() @abstractmethod def save(self, fileobj): - ''' Saves streamlines to a file-like object. + """ Saves streamlines to a file-like object. Parameters ---------- fileobj : string or file-like object If string, a filename; otherwise an open file-like object opened and ready to write. - ''' + + """ raise NotImplementedError() diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index a22a698c16..31fb3807cc 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -87,7 +87,7 @@ class TrkReader(object): - ''' Convenience class to encapsulate TRK file format. + """ Convenience class to encapsulate TRK file format. Parameters ---------- @@ -105,7 +105,8 @@ class TrkReader(object): Thus, streamlines are shifted of half a voxel on load and are shifted back on save. - ''' + + """ def __init__(self, fileobj): self.fileobj = fileobj @@ -222,7 +223,7 @@ def __iter__(self): class TrkWriter(object): @classmethod def create_empty_header(cls): - ''' Return an empty compliant TRK header. ''' + """ Return an empty compliant TRK header. """ header = np.zeros(1, dtype=header_2_dtype) # Default values @@ -421,7 +422,7 @@ def write(self, tractogram): class TrkFile(TractogramFile): - ''' Convenience class to encapsulate TRK file format. + """ Convenience class to encapsulate TRK file format. Note ---- @@ -432,7 +433,8 @@ class TrkFile(TractogramFile): Thus, streamlines are shifted of half a voxel on load and are shifted back on save. 
- ''' + + """ # Contants MAGIC_NUMBER = b"TRACK" @@ -443,10 +445,10 @@ def __init__(self, tractogram, header=None): """ Parameters ---------- - tractogram : :class:Tractogram object - Tractogram that will be contained in this :class:TrkFile. + tractogram : :class:`Tractogram` object + Tractogram that will be contained in this :class:`TrkFile`. - header : dict (optional) + header : dict, optional Metadata associated to this tractogram file. Notes @@ -463,22 +465,23 @@ def __init__(self, tractogram, header=None): @classmethod def get_magic_number(cls): - ''' Return TRK's magic number. ''' + """ Return TRK's magic number. """ return cls.MAGIC_NUMBER @classmethod def support_data_per_point(cls): - ''' Tells if this tractogram format supports saving data per point. ''' + """ Tells if this tractogram format supports saving data per point. """ return True @classmethod def support_data_per_streamline(cls): - ''' Tells if this tractogram format supports saving data per streamline. ''' + """ Tells if this tractogram format supports saving data per streamline. + """ return True @classmethod def is_correct_format(cls, fileobj): - ''' Check if the file is in TRK format. + """ Check if the file is in TRK format. Parameters ---------- @@ -489,9 +492,10 @@ def is_correct_format(cls, fileobj): Returns ------- - is_correct_format : boolean - Returns True if `fileobj` is in TRK format. - ''' + is_correct_format : {True, False} + Returns True if `fileobj` is compatible with TRK format, + otherwise returns False. + """ with Opener(fileobj) as f: magic_number = f.read(5) f.seek(-5, os.SEEK_CUR) @@ -587,7 +591,7 @@ def _create_compactlist_from_generator(cls, gen): @classmethod def load(cls, fileobj, lazy_load=False): - ''' Loads streamlines from a file-like object. + """ Loads streamlines from a file-like object. Parameters ---------- @@ -595,22 +599,23 @@ def load(cls, fileobj, lazy_load=False): If string, a filename; otherwise an open file-like object pointing to TRK file (and ready to read from the beginning of the TRK header). - - lazy_load : boolean (optional) - Load streamlines in a lazy manner i.e. they will not be kept - in memory. + lazy_load : {False, True}, optional + If True, load streamlines in a lazy manner i.e. they will not be + kept in memory. Otherwise, load all streamlines in memory. Returns ------- - trk_file : :class:TrkFile object + trk_file : :class:`TrkFile` object Returns an object containing tractogram data and header information. Notes ----- - Streamlines of the returned tractogram are assumed to be in RASmm - space where coordinate (0,0,0) refers to the center of the voxel. - ''' + Streamlines of the returned tractogram are assumed to be in *RAS* + and *mm* space where coordinate (0,0,0) refers to the center of the + voxel. + + """ trk_reader = TrkReader(fileobj) # TRK's streamlines are in 'voxelmm' space, we will compute the @@ -718,7 +723,7 @@ def _read(): return cls(tractogram, header=trk_reader.header) def save(self, fileobj): - ''' Saves tractogram to a file-like object using TRK format. + """ Saves tractogram to a file-like object using TRK format. Parameters ---------- @@ -726,12 +731,13 @@ def save(self, fileobj): If string, a filename; otherwise an open file-like object pointing to TRK file (and ready to read from the beginning of the TRK header data). - ''' + + """ trk_writer = TrkWriter(fileobj, self.header) trk_writer.write(self.tractogram) def __str__(self): - ''' Gets a formatted string of the header of a TRK file. 
+ """ Gets a formatted string of the header of a TRK file. Parameters ---------- @@ -744,7 +750,8 @@ def __str__(self): ------- info : string Header information relevant to the TRK format. - ''' + + """ hdr = self.header info = "" diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index 0fa1e4f47e..ce46c8f3eb 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -10,16 +10,22 @@ def get_affine_from_reference(ref): Parameter --------- - ref : filename | :class:Nifti1Image object | 2D array (4,4) - Reference space where streamlines live in `fileobj`. + ref : str or :class:`Nifti1Image` object or ndarray shape (4, 4) + If str then it's the filename of reference file that will be loaded + using :func:nibabel.load in order to obtain the affine. + If :class:`Nifti1Image` object then the affine is obtained from it. + If ndarray shape (4, 4) then it's the affine. Returns ------- - affine : 2D array (4,4) + affine : ndarray (4, 4) + Transformation matrix mapping voxel space to RAS+mm space. + """ if type(ref) is np.ndarray: if ref.shape != (4, 4): - raise ValueError("`ref` needs to be a numpy array with shape (4,4)!") + msg = "`ref` needs to be a numpy array with shape (4, 4)!" + raise ValueError(msg) return ref elif isinstance(ref, SpatialImage): @@ -30,6 +36,6 @@ def get_affine_from_reference(ref): def pop(iterable): - "Returns the next item from the iterable else None" + """ Returns the next item from the iterable else None. """ value = list(itertools.islice(iterable, 1)) return value[0] if len(value) > 0 else None From a8b2b6ead0a85d3838541b0714c1733cc6208467 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 9 Feb 2016 11:26:45 -0500 Subject: [PATCH 074/135] Renamed property shape for common_shape --- nibabel/streamlines/compact_list.py | 31 +++++++++---------- .../streamlines/tests/test_compact_list.py | 22 ++++++------- 2 files changed, 26 insertions(+), 27 deletions(-) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index fb2a9de14b..1f0ee96f9a 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -12,19 +12,18 @@ def __init__(self, iterable=None): """ Parameters ---------- - iterable : iterable (optional) - If specified, create a :class:`CompactList` object initialized from - iterable's items, otherwise it will be empty. + iterable : None or iterable of array-like objects or :class:`CompactList`, optional + If None, create an empty :class:`CompactList` object. + If iterable, create a :class:`CompactList` object initialized from + the iterable's items. + If :class:`CompactList`, create a view (no memory is allocated). + For an actual copy use :method:`CompactList.copy` instead. - Notes - ----- - If `iterable` is a :class:`CompactList` object, a view is returned and no - memory is allocated. For an actual copy use the `.copy()` method. """ # Create new empty `CompactList` object. self._data = np.array(0) - self._offsets = np.array([], dtype=int) - self._lengths = np.array([], dtype=int) + self._offsets = np.array([], dtype=np.intp) + self._lengths = np.array([], dtype=np.intp) if isinstance(iterable, CompactList): # Create a view. @@ -48,7 +47,7 @@ def __init__(self, iterable=None): # Resize needed, adding `len(e)` new items plus some buffer. 
nb_points = len(self._data) nb_points += len(e) + CompactList.BUFFER_SIZE - self._data.resize((nb_points,) + self.shape) + self._data.resize((nb_points,) + self.common_shape) offsets.append(offset) lengths.append(len(e)) @@ -60,13 +59,13 @@ def __init__(self, iterable=None): # Clear unused memory. if self._data.ndim != 0: - self._data.resize((offset,) + self.shape) + self._data.resize((offset,) + self.common_shape) @property - def shape(self): + def common_shape(self): """ Returns the matching shape of the elements in this compact list. """ if self._data.ndim == 0: - return None + return () return self._data.shape[1:] @@ -90,9 +89,9 @@ def append(self, element): self._lengths = np.array([len(element)]) return - if element.shape[1:] != self.shape: - raise ValueError("All dimensions, except the first one," - " must match exactly") + if element.shape[1:] != self.common_shape: + msg = "All dimensions, except the first one, must match exactly" + raise ValueError(msg) self._offsets = np.r_[self._offsets, len(self._data)] self._lengths = np.r_[self._lengths, len(element)] diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py index 67fe92475a..0ad461844c 100644 --- a/nibabel/streamlines/tests/test_compact_list.py +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -27,7 +27,7 @@ def test_creating_empty_compactlist(self): assert_equal(len(clist._offsets), 0) assert_equal(len(clist._lengths), 0) assert_equal(clist._data.ndim, 0) - assert_true(clist.shape is None) + assert_true(clist.common_shape == ()) def test_creating_compactlist_from_list(self): rng = np.random.RandomState(42) @@ -42,7 +42,7 @@ def test_creating_compactlist_from_list(self): assert_equal(clist._data.shape[1], 3) assert_array_equal(clist._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) assert_array_equal(clist._lengths, lengths) - assert_equal(clist.shape, data[0].shape[1:]) + assert_equal(clist.common_shape, data[0].shape[1:]) # Empty list clist = CompactList([]) @@ -50,7 +50,7 @@ def test_creating_compactlist_from_list(self): assert_equal(len(clist._offsets), 0) assert_equal(len(clist._lengths), 0) assert_equal(clist._data.ndim, 0) - assert_true(clist.shape is None) + assert_true(clist.common_shape == ()) # Force CompactList constructor to use buffering. 
old_buffer_size = CompactList.BUFFER_SIZE @@ -63,7 +63,7 @@ def test_creating_compactlist_from_list(self): assert_equal(clist._data.shape[1], 3) assert_array_equal(clist._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) assert_array_equal(clist._lengths, lengths) - assert_equal(clist.shape, data[0].shape[1:]) + assert_equal(clist.common_shape, data[0].shape[1:]) CompactList.BUFFER_SIZE = old_buffer_size def test_creating_compactlist_from_generator(self): @@ -80,7 +80,7 @@ def test_creating_compactlist_from_generator(self): assert_equal(clist._data.shape[1], 3) assert_array_equal(clist._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) assert_array_equal(clist._lengths, lengths) - assert_equal(clist.shape, data[0].shape[1:]) + assert_equal(clist.common_shape, data[0].shape[1:]) # Already consumed generator clist = CompactList(gen) @@ -88,7 +88,7 @@ def test_creating_compactlist_from_generator(self): assert_equal(len(clist._offsets), 0) assert_equal(len(clist._lengths), 0) assert_equal(clist._data.ndim, 0) - assert_true(clist.shape is None) + assert_true(clist.common_shape == ()) def test_creating_compactlist_from_compact_list(self): rng = np.random.RandomState(42) @@ -104,7 +104,7 @@ def test_creating_compactlist_from_compact_list(self): assert_equal(clist2._data.shape[1], 3) assert_array_equal(clist2._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) assert_array_equal(clist2._lengths, lengths) - assert_equal(clist2.shape, data[0].shape[1:]) + assert_equal(clist2.common_shape, data[0].shape[1:]) def test_compactlist_iter(self): for e, d in zip(self.clist, self.data): @@ -124,7 +124,7 @@ def test_compactlist_copy(self): assert_array_equal(clist._lengths, self.clist._lengths) assert_true(clist._lengths is not self.clist._lengths) - assert_equal(clist.shape, self.clist.shape) + assert_equal(clist.common_shape, self.clist.common_shape) # When taking a copy of a `CompactList` generated by slicing. # Only needed data should be kept. 
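A sketch of the renamed property (in-progress API): `common_shape` is the shape shared by all elements once their first axis is ignored, and is empty for an empty sequence:

    import numpy as np
    from nibabel.streamlines.compact_list import CompactList

    clist = CompactList([np.random.rand(12, 3), np.random.rand(7, 3)])
    assert clist.common_shape == (3,)
    assert CompactList().common_shape == ()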
@@ -144,7 +144,7 @@ def test_compactlist_append(self): clist = self.clist.copy() rng = np.random.RandomState(1234) - element = rng.rand(rng.randint(10, 50), *self.clist.shape) + element = rng.rand(rng.randint(10, 50), *self.clist.common_shape) clist.append(element) assert_equal(len(clist), len(self.clist)+1) assert_equal(clist._offsets[-1], len(self.clist._data)) @@ -166,7 +166,7 @@ def test_compactlist_append(self): assert_equal(clist._offsets[-1], 0) assert_equal(clist._lengths[-1], len(element)) assert_array_equal(clist._data, element) - assert_equal(clist.shape, shape) + assert_equal(clist.common_shape, shape) def test_compactlist_extend(self): # Maybe not necessary if `self.setUp` is always called before a @@ -174,7 +174,7 @@ def test_compactlist_extend(self): clist = self.clist.copy() rng = np.random.RandomState(1234) - shape = self.clist.shape + shape = self.clist.common_shape new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(10)] lengths = list(map(len, new_data)) clist.extend(new_data) From fbef9a5a37e8a85ea615d0a7774196d8f76176e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 9 Feb 2016 11:41:49 -0500 Subject: [PATCH 075/135] Moved functions save_compact_list and load_compact_list functions into CompactList as methods --- nibabel/streamlines/compact_list.py | 40 +++++++-------- .../streamlines/tests/test_compact_list.py | 49 +++++++++---------- 2 files changed, 44 insertions(+), 45 deletions(-) diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/compact_list.py index 1f0ee96f9a..fdcfc5555e 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/compact_list.py @@ -166,14 +166,15 @@ def __getitem__(self, idx): Parameters ---------- - idx : int, slice or list + idx : int or slice or list or ndarray of bool dtype or ndarray of int dtype Index of the element(s) to get. Returns ------- - ndarray object(s) + ndarray(s) When `idx` is an int, returns a single ndarray. - When `idx` is either a slice or a list, returns a list of ndarrays. + When `idx` is either a slice, a list or a ndarray, returns a list + of ndarrays. """ if isinstance(idx, int) or isinstance(idx, np.integer): start = self._offsets[idx] @@ -219,20 +220,19 @@ def __len__(self): def __repr__(self): return repr(list(self)) - -def save_compact_list(filename, clist): - """ Saves a :class:`CompactList` object to a .npz file. """ - np.savez(filename, - data=clist._data, - offsets=clist._offsets, - lengths=clist._lengths) - - -def load_compact_list(filename): - """ Loads a :class:`CompactList` object from a .npz file. """ - content = np.load(filename) - clist = CompactList() - clist._data = content["data"] - clist._offsets = content["offsets"] - clist._lengths = content["lengths"] - return clist + def save(self, filename): + """ Saves this :class:`CompactList` object to a .npz file. """ + np.savez(filename, + data=self._data, + offsets=self._offsets, + lengths=self._lengths) + + @classmethod + def from_filename(cls, filename): + """ Loads a :class:`CompactList` object from a .npz file. 
""" + content = np.load(filename) + clist = cls() + clist._data = content["data"] + clist._offsets = content["offsets"] + clist._lengths = content["lengths"] + return clist diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py index 0ad461844c..d0ef04951c 100644 --- a/nibabel/streamlines/tests/test_compact_list.py +++ b/nibabel/streamlines/tests/test_compact_list.py @@ -8,9 +8,7 @@ from numpy.testing import assert_array_equal from nibabel.externals.six.moves import zip, zip_longest -from ..compact_list import (CompactList, - load_compact_list, - save_compact_list) +from ..compact_list import CompactList class TestCompactList(unittest.TestCase): @@ -277,25 +275,26 @@ def test_compactlist_repr(self): # Test that calling repr on a CompactList object is not falling. repr(self.clist) - -def test_save_and_load_compact_list(): - - with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: - clist = CompactList() - save_compact_list(f, clist) - f.seek(0, os.SEEK_SET) - loaded_clist = load_compact_list(f) - assert_array_equal(loaded_clist._data, clist._data) - assert_array_equal(loaded_clist._offsets, clist._offsets) - assert_array_equal(loaded_clist._lengths, clist._lengths) - - with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: - rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - clist = CompactList(data) - save_compact_list(f, clist) - f.seek(0, os.SEEK_SET) - loaded_clist = load_compact_list(f) - assert_array_equal(loaded_clist._data, clist._data) - assert_array_equal(loaded_clist._offsets, clist._offsets) - assert_array_equal(loaded_clist._lengths, clist._lengths) + def test_save_and_load_compact_list(self): + + # Test saving and loading an empty CompactList. + with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + clist = CompactList() + clist.save(f) + f.seek(0, os.SEEK_SET) + loaded_clist = CompactList.from_filename(f) + assert_array_equal(loaded_clist._data, clist._data) + assert_array_equal(loaded_clist._offsets, clist._offsets) + assert_array_equal(loaded_clist._lengths, clist._lengths) + + # Test saving and loading a CompactList. + with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + clist = CompactList(data) + clist.save(f) + f.seek(0, os.SEEK_SET) + loaded_clist = CompactList.from_filename(f) + assert_array_equal(loaded_clist._data, clist._data) + assert_array_equal(loaded_clist._offsets, clist._offsets) + assert_array_equal(loaded_clist._lengths, clist._lengths) From 646ac99985b67a38dd1eb173210071c78b232a65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 9 Feb 2016 11:56:04 -0500 Subject: [PATCH 076/135] Addressed @matthew-brett's comments. --- nibabel/streamlines/tractogram.py | 3 +-- nibabel/streamlines/tractogram_file.py | 3 +-- nibabel/streamlines/trk.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index ea0fa4e4a4..32f25fc04a 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -109,8 +109,7 @@ def __setitem__(self, key, value): self.store[key] = value class DataPerPointDict(DataDict): - """ Internal dictionary making sure data are :class:`CompactList` objects. - """ + """ Internal dictionary making sure data are :class:`CompactList` objects. 
""" def __setitem__(self, key, value): value = CompactList(value) diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index ab0b05467f..cdd0463d5d 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -72,8 +72,7 @@ def support_data_per_point(cls): @abstractclassmethod def support_data_per_streamline(cls): - """ Tells if this tractogram format supports saving data per streamline. - """ + """ Tells if this tractogram format supports saving data per streamline. """ raise NotImplementedError() @abstractclassmethod diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 31fb3807cc..259e5f4b2f 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -475,8 +475,7 @@ def support_data_per_point(cls): @classmethod def support_data_per_streamline(cls): - """ Tells if this tractogram format supports saving data per streamline. - """ + """ Tells if this tractogram format supports saving data per streamline. """ return True @classmethod From d34c45261f25820c0c5e72cbdb65c74049ec5f72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 9 Feb 2016 17:16:58 -0500 Subject: [PATCH 077/135] RFG: renamed CompactList to ArraySequence --- nibabel/streamlines/__init__.py | 2 +- .../{compact_list.py => array_sequence.py} | 112 ++++--- .../streamlines/tests/test_array_sequence.py | 300 ++++++++++++++++++ .../streamlines/tests/test_compact_list.py | 300 ------------------ nibabel/streamlines/tractogram.py | 8 +- nibabel/streamlines/trk.py | 22 +- 6 files changed, 381 insertions(+), 363 deletions(-) rename nibabel/streamlines/{compact_list.py => array_sequence.py} (62%) create mode 100644 nibabel/streamlines/tests/test_array_sequence.py delete mode 100644 nibabel/streamlines/tests/test_compact_list.py diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 2550373865..6a1d8a87cd 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -2,7 +2,7 @@ from ..externals.six import string_types from .header import Field -from .compact_list import CompactList +from .array_sequence import ArraySequence from .tractogram import Tractogram, LazyTractogram from .tractogram_file import TractogramFile diff --git a/nibabel/streamlines/compact_list.py b/nibabel/streamlines/array_sequence.py similarity index 62% rename from nibabel/streamlines/compact_list.py rename to nibabel/streamlines/array_sequence.py index fdcfc5555e..2f3a82739d 100644 --- a/nibabel/streamlines/compact_list.py +++ b/nibabel/streamlines/array_sequence.py @@ -1,9 +1,18 @@ import numpy as np -class CompactList(object): - """ Class for compacting list of ndarrays with matching shape except for - the first dimension. +class ArraySequence(object): + """ Sequence of ndarrays having variable first dimension sizes. + + This is a container allowing to store multiple ndarrays where each ndarray + might have different first dimension size but a *common* size for the + remaining dimensions. + + More generally, an instance of :class:`ArraySequence` of length $N$ is + compoosed of $N$ ndarrays of shape $(d_n, d_2, ... d_D)$ where + $n \in [1,N]$, $d_n$ might vary from one ndarray to another and + $d_2, ..., d_D)$ have to be the same for every ndarray. + """ BUFFER_SIZE = 87382*4 # About 4 Mb if item shape is 3 (e.g. 3D points). 
@@ -12,20 +21,20 @@ def __init__(self, iterable=None): """ Parameters ---------- - iterable : None or iterable of array-like objects or :class:`CompactList`, optional - If None, create an empty :class:`CompactList` object. - If iterable, create a :class:`CompactList` object initialized from - the iterable's items. - If :class:`CompactList`, create a view (no memory is allocated). - For an actual copy use :method:`CompactList.copy` instead. + iterable : None or iterable or :class:`ArraySequence`, optional + If None, create an empty :class:`ArraySequence` object. + If iterable, create a :class:`ArraySequence` object initialized + from array-like objects yielded by the iterable. + If :class:`ArraySequence`, create a view (no memory is allocated). + For an actual copy use :meth:`.copy` instead. """ - # Create new empty `CompactList` object. + # Create new empty `ArraySequence` object. self._data = np.array(0) self._offsets = np.array([], dtype=np.intp) self._lengths = np.array([], dtype=np.intp) - if isinstance(iterable, CompactList): + if isinstance(iterable, ArraySequence): # Create a view. self._data = iterable._data self._offsets = iterable._offsets @@ -34,19 +43,19 @@ def __init__(self, iterable=None): elif iterable is not None: offsets = [] lengths = [] - # Initialize the `CompactList` object from iterable's item. + # Initialize the `ArraySequence` object from iterable's item. offset = 0 for i, e in enumerate(iterable): e = np.asarray(e) if i == 0: - new_shape = (CompactList.BUFFER_SIZE,) + e.shape[1:] + new_shape = (ArraySequence.BUFFER_SIZE,) + e.shape[1:] self._data = np.empty(new_shape, dtype=e.dtype) end = offset + len(e) if end >= len(self._data): - # Resize needed, adding `len(e)` new items plus some buffer. + # Resize needed, adding `len(e)` items plus some buffer. nb_points = len(self._data) - nb_points += len(e) + CompactList.BUFFER_SIZE + nb_points += len(e) + ArraySequence.BUFFER_SIZE self._data.resize((nb_points,) + self.common_shape) offsets.append(offset) @@ -63,14 +72,14 @@ def __init__(self, iterable=None): @property def common_shape(self): - """ Returns the matching shape of the elements in this compact list. """ + """ Matching shape of the elements in this array sequence. """ if self._data.ndim == 0: return () return self._data.shape[1:] def append(self, element): - """ Appends `element` to this compact list. + """ Appends :obj:`element` to this array sequence. Parameters ---------- @@ -81,7 +90,7 @@ def append(self, element): Notes ----- If you need to add multiple elements you should consider - `CompactList.extend`. + `ArraySequence.extend`. """ if self._data.ndim == 0: self._data = np.asarray(element).copy() @@ -98,20 +107,20 @@ def append(self, element): self._data = np.append(self._data, element, axis=0) def extend(self, elements): - """ Appends all `elements` to this compact list. + """ Appends all `elements` to this array sequence. Parameters ---------- - elements : list of ndarrays or :class:`CompactList` object + elements : list of ndarrays or :class:`ArraySequence` object If list of ndarrays, each ndarray will be concatenated along the - first dimension then appended to the data of this CompactList. - If :class:`CompactList` object, its data are simply appended to - the data of this CompactList. + first dimension then appended to the data of this ArraySequence. + If :class:`ArraySequence` object, its data are simply appended to + the data of this ArraySequence. 
Notes ----- - The shape of the elements to be added must match the one of the data - of this CompactList except for the first dimension. + The shape of the elements to be added must match the one of the + data of this :class:`ArraySequence` except for the first dimension. """ if self._data.ndim == 0: @@ -120,14 +129,15 @@ def extend(self, elements): next_offset = self._data.shape[0] - if isinstance(elements, CompactList): + if isinstance(elements, ArraySequence): self._data.resize((self._data.shape[0]+sum(elements._lengths), self._data.shape[1])) offsets = [] for offset, length in zip(elements._offsets, elements._lengths): offsets.append(next_offset) - self._data[next_offset:next_offset+length] = elements._data[offset:offset+length] + chunk = elements._data[offset:offset+length] + self._data[next_offset:next_offset+length] = chunk next_offset += length self._lengths = np.r_[self._lengths, elements._lengths] @@ -137,14 +147,15 @@ def extend(self, elements): self._data = np.concatenate([self._data] + list(elements), axis=0) lengths = list(map(len, elements)) self._lengths = np.r_[self._lengths, lengths] - self._offsets = np.r_[self._offsets, np.cumsum([next_offset] + lengths)[:-1]] + self._offsets = np.r_[self._offsets, + np.cumsum([next_offset] + lengths)[:-1]] def copy(self): - """ Creates a copy of this :class:`CompactList` object. """ + """ Creates a copy of this :class:`ArraySequence` object. """ # We do not simply deepcopy this object since we might have a chance - # to use less memory. For example, if the compact list being copied - # is the result of a slicing operation on a compact list. - clist = CompactList() + # to use less memory. For example, if the array sequence being copied + # is the result of a slicing operation on a array sequence. + clist = ArraySequence() total_lengths = np.sum(self._lengths) clist._data = np.empty((total_lengths,) + self._data.shape[1:], dtype=self._data.dtype) @@ -153,7 +164,8 @@ def copy(self): offsets = [] for offset, length in zip(self._offsets, self._lengths): offsets.append(next_offset) - clist._data[next_offset:next_offset+length] = self._data[offset:offset+length] + chunk = self._data[offset:offset+length] + clist._data[next_offset:next_offset+length] = chunk next_offset += length clist._offsets = np.asarray(offsets) @@ -162,40 +174,46 @@ def copy(self): return clist def __getitem__(self, idx): - """ Gets element(s) through indexing. + """ Gets sequence(s) through advanced indexing. Parameters ---------- - idx : int or slice or list or ndarray of bool dtype or ndarray of int dtype - Index of the element(s) to get. + idx : int or slice or list or ndarray + If int, index of the element to retrieve. + If slice, use slicing to retrieve elements. + If list, indices of the elements to retrieve. + If ndarray with dtype int, indices of the elements to retrieve. + If ndarray with dtype bool, only retrieve selected elements. Returns ------- - ndarray(s) - When `idx` is an int, returns a single ndarray. - When `idx` is either a slice, a list or a ndarray, returns a list - of ndarrays. + ndarray or :class:`ArraySequence` + If `idx` is an int, returns the selected sequence. + Otherwise, returns a :class:`ArraySequence` object which is view + of the selected sequences. 
+ """ if isinstance(idx, int) or isinstance(idx, np.integer): start = self._offsets[idx] return self._data[start:start+self._lengths[idx]] elif isinstance(idx, slice) or isinstance(idx, list): - clist = CompactList() + clist = ArraySequence() clist._data = self._data clist._offsets = self._offsets[idx] clist._lengths = self._lengths[idx] return clist - elif isinstance(idx, np.ndarray) and np.issubdtype(idx.dtype, np.integer): - clist = CompactList() + elif isinstance(idx, np.ndarray) and np.issubdtype(idx.dtype, + np.integer): + clist = ArraySequence() clist._data = self._data clist._offsets = self._offsets[idx] clist._lengths = self._lengths[idx] return clist elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: - clist = CompactList() + clist = ArraySequence() clist._data = self._data clist._offsets = [self._offsets[i] for i, take_it in enumerate(idx) if take_it] @@ -208,7 +226,7 @@ def __getitem__(self, idx): def __iter__(self): if len(self._lengths) != len(self._offsets): - raise ValueError("CompactList object corrupted:" + raise ValueError("ArraySequence object corrupted:" " len(self._lengths) != len(self._offsets)") for offset, lengths in zip(self._offsets, self._lengths): @@ -221,7 +239,7 @@ def __repr__(self): return repr(list(self)) def save(self, filename): - """ Saves this :class:`CompactList` object to a .npz file. """ + """ Saves this :class:`ArraySequence` object to a .npz file. """ np.savez(filename, data=self._data, offsets=self._offsets, @@ -229,7 +247,7 @@ def save(self, filename): @classmethod def from_filename(cls, filename): - """ Loads a :class:`CompactList` object from a .npz file. """ + """ Loads a :class:`ArraySequence` object from a .npz file. """ content = np.load(filename) clist = cls() clist._data = content["data"] diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py new file mode 100644 index 0000000000..57efcfae44 --- /dev/null +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -0,0 +1,300 @@ +import os +import unittest +import tempfile +import numpy as np + +from nose.tools import assert_equal, assert_raises, assert_true +from nibabel.testing import assert_arrays_equal +from numpy.testing import assert_array_equal +from nibabel.externals.six.moves import zip, zip_longest + +from ..array_sequence import ArraySequence + + +class TestArraySequence(unittest.TestCase): + + def setUp(self): + rng = np.random.RandomState(42) + self.data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + self.lengths = list(map(len, self.data)) + self.seq = ArraySequence(self.data) + + def test_creating_empty_arraysequence(self): + seq = ArraySequence() + assert_equal(len(seq), 0) + assert_equal(len(seq._offsets), 0) + assert_equal(len(seq._lengths), 0) + assert_equal(seq._data.ndim, 0) + assert_true(seq.common_shape == ()) + + def test_creating_arraysequence_from_list(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = list(map(len, data)) + + seq = ArraySequence(data) + assert_equal(len(seq), len(data)) + assert_equal(len(seq._offsets), len(data)) + assert_equal(len(seq._lengths), len(data)) + assert_equal(seq._data.shape[0], sum(lengths)) + assert_equal(seq._data.shape[1], 3) + assert_array_equal(seq._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) + assert_array_equal(seq._lengths, lengths) + assert_equal(seq.common_shape, data[0].shape[1:]) + + # Empty list + seq = ArraySequence([]) + assert_equal(len(seq), 0) + 
assert_equal(len(seq._offsets), 0) + assert_equal(len(seq._lengths), 0) + assert_equal(seq._data.ndim, 0) + assert_true(seq.common_shape == ()) + + # Force ArraySequence constructor to use buffering. + old_buffer_size = ArraySequence.BUFFER_SIZE + ArraySequence.BUFFER_SIZE = 1 + seq = ArraySequence(data) + assert_equal(len(seq), len(data)) + assert_equal(len(seq._offsets), len(data)) + assert_equal(len(seq._lengths), len(data)) + assert_equal(seq._data.shape[0], sum(lengths)) + assert_equal(seq._data.shape[1], 3) + assert_array_equal(seq._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) + assert_array_equal(seq._lengths, lengths) + assert_equal(seq.common_shape, data[0].shape[1:]) + ArraySequence.BUFFER_SIZE = old_buffer_size + + def test_creating_arraysequence_from_generator(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = list(map(len, data)) + + gen = (e for e in data) + seq = ArraySequence(gen) + assert_equal(len(seq), len(data)) + assert_equal(len(seq._offsets), len(data)) + assert_equal(len(seq._lengths), len(data)) + assert_equal(seq._data.shape[0], sum(lengths)) + assert_equal(seq._data.shape[1], 3) + assert_array_equal(seq._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) + assert_array_equal(seq._lengths, lengths) + assert_equal(seq.common_shape, data[0].shape[1:]) + + # Already consumed generator + seq = ArraySequence(gen) + assert_equal(len(seq), 0) + assert_equal(len(seq._offsets), 0) + assert_equal(len(seq._lengths), 0) + assert_equal(seq._data.ndim, 0) + assert_true(seq.common_shape == ()) + + def test_creating_arraysequence_from_arraysequence(self): + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = list(map(len, data)) + + seq = ArraySequence(data) + seq2 = ArraySequence(seq) + assert_equal(len(seq2), len(data)) + assert_equal(len(seq2._offsets), len(data)) + assert_equal(len(seq2._lengths), len(data)) + assert_equal(seq2._data.shape[0], sum(lengths)) + assert_equal(seq2._data.shape[1], 3) + assert_array_equal(seq2._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) + assert_array_equal(seq2._lengths, lengths) + assert_equal(seq2.common_shape, data[0].shape[1:]) + + def test_arraysequence_iter(self): + for e, d in zip(self.seq, self.data): + assert_array_equal(e, d) + + # Try iterate through a corrupted ArraySequence object. + seq = self.seq.copy() + seq._lengths = seq._lengths[::2] + assert_raises(ValueError, list, seq) + + def test_arraysequence_copy(self): + seq = self.seq.copy() + assert_array_equal(seq._data, self.seq._data) + assert_true(seq._data is not self.seq._data) + assert_array_equal(seq._offsets, self.seq._offsets) + assert_true(seq._offsets is not self.seq._offsets) + assert_array_equal(seq._lengths, self.seq._lengths) + assert_true(seq._lengths is not self.seq._lengths) + + assert_equal(seq.common_shape, self.seq.common_shape) + + # When taking a copy of a `ArraySequence` generated by slicing. + # Only needed data should be kept. + seq = self.seq[::2].copy() + + assert_true(seq._data.shape[0] < self.seq._data.shape[0]) + assert_true(len(seq) < len(self.seq)) + assert_true(seq._data is not self.seq._data) + assert_array_equal(seq._lengths, self.seq[::2]._lengths) + assert_array_equal(seq._offsets, + np.cumsum(np.r_[0, self.seq[::2]._lengths])[:-1]) + assert_arrays_equal(seq, self.seq[::2]) + + def test_arraysequence_append(self): + # Maybe not necessary if `self.setUp` is always called before a + # test method, anyways create a copy just in case. 
+ seq = self.seq.copy() + + rng = np.random.RandomState(1234) + element = rng.rand(rng.randint(10, 50), *self.seq.common_shape) + seq.append(element) + assert_equal(len(seq), len(self.seq)+1) + assert_equal(seq._offsets[-1], len(self.seq._data)) + assert_equal(seq._lengths[-1], len(element)) + assert_array_equal(seq._data[-len(element):], element) + + # Append with different shape. + element = rng.rand(rng.randint(10, 50), 42) + assert_raises(ValueError, seq.append, element) + + # Append to an empty ArraySequence. + seq = ArraySequence() + rng = np.random.RandomState(1234) + shape = (2, 3, 4) + element = rng.rand(rng.randint(10, 50), *shape) + seq.append(element) + + assert_equal(len(seq), 1) + assert_equal(seq._offsets[-1], 0) + assert_equal(seq._lengths[-1], len(element)) + assert_array_equal(seq._data, element) + assert_equal(seq.common_shape, shape) + + def test_arraysequence_extend(self): + # Maybe not necessary if `self.setUp` is always called before a + # test method, anyways create a copy just in case. + seq = self.seq.copy() + + rng = np.random.RandomState(1234) + shape = self.seq.common_shape + new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(10)] + lengths = list(map(len, new_data)) + seq.extend(new_data) + assert_equal(len(seq), len(self.seq)+len(new_data)) + assert_array_equal(seq._offsets[-len(new_data):], + len(self.seq._data) + np.cumsum([0] + lengths[:-1])) + + assert_array_equal(seq._lengths[-len(new_data):], lengths) + assert_array_equal(seq._data[-sum(lengths):], + np.concatenate(new_data, axis=0)) + + # Extend with another `ArraySequence` object. + seq = self.seq.copy() + new_seq = ArraySequence(new_data) + seq.extend(new_seq) + assert_equal(len(seq), len(self.seq)+len(new_seq)) + assert_array_equal(seq._offsets[-len(new_seq):], + len(self.seq._data) + np.cumsum(np.r_[0, lengths[:-1]])) + + assert_array_equal(seq._lengths[-len(new_seq):], lengths) + assert_array_equal(seq._data[-sum(lengths):], new_seq._data) + + # Extend with another `ArraySequence` object that is a view (e.g. been sliced). + # Need to make sure we extend only the data we need. + seq = self.seq.copy() + new_seq = ArraySequence(new_data)[::2] + seq.extend(new_seq) + assert_equal(len(seq), len(self.seq)+len(new_seq)) + assert_equal(len(seq._data), len(self.seq._data)+sum(new_seq._lengths)) + assert_array_equal(seq._offsets[-len(new_seq):], + len(self.seq._data) + np.cumsum(np.r_[0, new_seq._lengths[:-1]])) + + assert_array_equal(seq._lengths[-len(new_seq):], lengths[::2]) + assert_array_equal(seq._data[-sum(new_seq._lengths):], new_seq.copy()._data) + assert_arrays_equal(seq[-len(new_seq):], new_seq) + + # Test extending an empty ArraySequence + seq = ArraySequence() + new_seq = ArraySequence(new_data) + seq.extend(new_seq) + assert_equal(len(seq), len(new_seq)) + assert_array_equal(seq._offsets, new_seq._offsets) + assert_array_equal(seq._lengths, new_seq._lengths) + assert_array_equal(seq._data, new_seq._data) + + def test_arraysequence_getitem(self): + # Get one item + for i, e in enumerate(self.seq): + assert_array_equal(self.seq[i], e) + + # Get multiple items (this will create a view). 
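
"View" here means that indexing with a slice, a list or an integer ndarray re-derives only the `_offsets` and `_lengths` bookkeeping while the underlying `_data` buffer stays shared. A short sketch, assuming the `array_sequence` module introduced by this patch:

    import numpy as np
    from nibabel.streamlines.array_sequence import ArraySequence

    seq = ArraySequence([np.random.rand(5, 3) for _ in range(4)])
    view = seq[[0, 2]]              # list indexing returns a view, not a copy
    assert view._data is seq._data  # the point buffer is shared
    assert len(view) == 2           # only offsets/lengths were re-indexed
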
+        indices = list(range(len(self.seq)))
+        seq_view = self.seq[indices]
+        assert_true(seq_view is not self.seq)
+        assert_true(seq_view._data is self.seq._data)
+        assert_true(seq_view._offsets is not self.seq._offsets)
+        assert_true(seq_view._lengths is not self.seq._lengths)
+        assert_array_equal(seq_view._offsets, self.seq._offsets)
+        assert_array_equal(seq_view._lengths, self.seq._lengths)
+        assert_arrays_equal(seq_view, self.seq)
+
+        # Get multiple items using ndarray of integer data type.
+        for dtype in [np.int8, np.int16, np.int32, np.int64]:
+            seq_view = self.seq[np.array(indices, dtype=dtype)]
+            assert_true(seq_view is not self.seq)
+            assert_true(seq_view._data is self.seq._data)
+            assert_true(seq_view._offsets is not self.seq._offsets)
+            assert_true(seq_view._lengths is not self.seq._lengths)
+            assert_array_equal(seq_view._offsets, self.seq._offsets)
+            assert_array_equal(seq_view._lengths, self.seq._lengths)
+            for e1, e2 in zip_longest(seq_view, self.seq):
+                assert_array_equal(e1, e2)
+
+        # Get slice (this will create a view).
+        seq_view = self.seq[::2]
+        assert_true(seq_view is not self.seq)
+        assert_true(seq_view._data is self.seq._data)
+        assert_array_equal(seq_view._offsets, self.seq._offsets[::2])
+        assert_array_equal(seq_view._lengths, self.seq._lengths[::2])
+        for i, e in enumerate(seq_view):
+            assert_array_equal(e, self.seq[i*2])
+
+        # Use advance indexing with ndarray of data type bool.
+        idx = np.array([False, True, True, False, True])
+        seq_view = self.seq[idx]
+        assert_true(seq_view is not self.seq)
+        assert_true(seq_view._data is self.seq._data)
+        assert_array_equal(seq_view._offsets,
+                           self.seq._offsets[idx])
+        assert_array_equal(seq_view._lengths,
+                           self.seq._lengths[idx])
+        assert_array_equal(seq_view[0], self.seq[1])
+        assert_array_equal(seq_view[1], self.seq[2])
+        assert_array_equal(seq_view[2], self.seq[4])
+
+        # Test invalid indexing
+        assert_raises(TypeError, self.seq.__getitem__, 'abc')
+
+    def test_arraysequence_repr(self):
+        # Test that calling repr on an ArraySequence object is not failing.
+        repr(self.seq)
+
+    def test_save_and_load_arraysequence(self):
+
+        # Test saving and loading an empty ArraySequence.
+        with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f:
+            seq = ArraySequence()
+            seq.save(f)
+            f.seek(0, os.SEEK_SET)
+            loaded_seq = ArraySequence.from_filename(f)
+            assert_array_equal(loaded_seq._data, seq._data)
+            assert_array_equal(loaded_seq._offsets, seq._offsets)
+            assert_array_equal(loaded_seq._lengths, seq._lengths)
+
+        # Test saving and loading an ArraySequence.
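
These round trips work because `save` writes the three internal arrays to a `.npz` archive under the keys `data`, `offsets` and `lengths`, and `from_filename` simply reads them back. Stripped to the bare NumPy calls (the file name is illustrative):

    import numpy as np

    np.savez('seq.npz', data=np.zeros((3, 3)),
             offsets=np.array([0]), lengths=np.array([3]))
    content = np.load('seq.npz')
    data, offsets, lengths = (content['data'], content['offsets'],
                              content['lengths'])
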
+ with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: + rng = np.random.RandomState(42) + data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + seq = ArraySequence(data) + seq.save(f) + f.seek(0, os.SEEK_SET) + loaded_seq = ArraySequence.from_filename(f) + assert_array_equal(loaded_seq._data, seq._data) + assert_array_equal(loaded_seq._offsets, seq._offsets) + assert_array_equal(loaded_seq._lengths, seq._lengths) diff --git a/nibabel/streamlines/tests/test_compact_list.py b/nibabel/streamlines/tests/test_compact_list.py deleted file mode 100644 index d0ef04951c..0000000000 --- a/nibabel/streamlines/tests/test_compact_list.py +++ /dev/null @@ -1,300 +0,0 @@ -import os -import unittest -import tempfile -import numpy as np - -from nose.tools import assert_equal, assert_raises, assert_true -from nibabel.testing import assert_arrays_equal -from numpy.testing import assert_array_equal -from nibabel.externals.six.moves import zip, zip_longest - -from ..compact_list import CompactList - - -class TestCompactList(unittest.TestCase): - - def setUp(self): - rng = np.random.RandomState(42) - self.data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - self.lengths = list(map(len, self.data)) - self.clist = CompactList(self.data) - - def test_creating_empty_compactlist(self): - clist = CompactList() - assert_equal(len(clist), 0) - assert_equal(len(clist._offsets), 0) - assert_equal(len(clist._lengths), 0) - assert_equal(clist._data.ndim, 0) - assert_true(clist.common_shape == ()) - - def test_creating_compactlist_from_list(self): - rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = list(map(len, data)) - - clist = CompactList(data) - assert_equal(len(clist), len(data)) - assert_equal(len(clist._offsets), len(data)) - assert_equal(len(clist._lengths), len(data)) - assert_equal(clist._data.shape[0], sum(lengths)) - assert_equal(clist._data.shape[1], 3) - assert_array_equal(clist._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) - assert_array_equal(clist._lengths, lengths) - assert_equal(clist.common_shape, data[0].shape[1:]) - - # Empty list - clist = CompactList([]) - assert_equal(len(clist), 0) - assert_equal(len(clist._offsets), 0) - assert_equal(len(clist._lengths), 0) - assert_equal(clist._data.ndim, 0) - assert_true(clist.common_shape == ()) - - # Force CompactList constructor to use buffering. 
- old_buffer_size = CompactList.BUFFER_SIZE - CompactList.BUFFER_SIZE = 1 - clist = CompactList(data) - assert_equal(len(clist), len(data)) - assert_equal(len(clist._offsets), len(data)) - assert_equal(len(clist._lengths), len(data)) - assert_equal(clist._data.shape[0], sum(lengths)) - assert_equal(clist._data.shape[1], 3) - assert_array_equal(clist._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) - assert_array_equal(clist._lengths, lengths) - assert_equal(clist.common_shape, data[0].shape[1:]) - CompactList.BUFFER_SIZE = old_buffer_size - - def test_creating_compactlist_from_generator(self): - rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = list(map(len, data)) - - gen = (e for e in data) - clist = CompactList(gen) - assert_equal(len(clist), len(data)) - assert_equal(len(clist._offsets), len(data)) - assert_equal(len(clist._lengths), len(data)) - assert_equal(clist._data.shape[0], sum(lengths)) - assert_equal(clist._data.shape[1], 3) - assert_array_equal(clist._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) - assert_array_equal(clist._lengths, lengths) - assert_equal(clist.common_shape, data[0].shape[1:]) - - # Already consumed generator - clist = CompactList(gen) - assert_equal(len(clist), 0) - assert_equal(len(clist._offsets), 0) - assert_equal(len(clist._lengths), 0) - assert_equal(clist._data.ndim, 0) - assert_true(clist.common_shape == ()) - - def test_creating_compactlist_from_compact_list(self): - rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = list(map(len, data)) - - clist = CompactList(data) - clist2 = CompactList(clist) - assert_equal(len(clist2), len(data)) - assert_equal(len(clist2._offsets), len(data)) - assert_equal(len(clist2._lengths), len(data)) - assert_equal(clist2._data.shape[0], sum(lengths)) - assert_equal(clist2._data.shape[1], 3) - assert_array_equal(clist2._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) - assert_array_equal(clist2._lengths, lengths) - assert_equal(clist2.common_shape, data[0].shape[1:]) - - def test_compactlist_iter(self): - for e, d in zip(self.clist, self.data): - assert_array_equal(e, d) - - # Try iterate through a corrupted CompactList object. - clist = self.clist.copy() - clist._lengths = clist._lengths[::2] - assert_raises(ValueError, list, clist) - - def test_compactlist_copy(self): - clist = self.clist.copy() - assert_array_equal(clist._data, self.clist._data) - assert_true(clist._data is not self.clist._data) - assert_array_equal(clist._offsets, self.clist._offsets) - assert_true(clist._offsets is not self.clist._offsets) - assert_array_equal(clist._lengths, self.clist._lengths) - assert_true(clist._lengths is not self.clist._lengths) - - assert_equal(clist.common_shape, self.clist.common_shape) - - # When taking a copy of a `CompactList` generated by slicing. - # Only needed data should be kept. - clist = self.clist[::2].copy() - - assert_true(clist._data.shape[0] < self.clist._data.shape[0]) - assert_true(len(clist) < len(self.clist)) - assert_true(clist._data is not self.clist._data) - assert_array_equal(clist._lengths, self.clist[::2]._lengths) - assert_array_equal(clist._offsets, - np.cumsum(np.r_[0, self.clist[::2]._lengths])[:-1]) - assert_arrays_equal(clist, self.clist[::2]) - - def test_compactlist_append(self): - # Maybe not necessary if `self.setUp` is always called before a - # test method, anyways create a copy just in case. 
- clist = self.clist.copy() - - rng = np.random.RandomState(1234) - element = rng.rand(rng.randint(10, 50), *self.clist.common_shape) - clist.append(element) - assert_equal(len(clist), len(self.clist)+1) - assert_equal(clist._offsets[-1], len(self.clist._data)) - assert_equal(clist._lengths[-1], len(element)) - assert_array_equal(clist._data[-len(element):], element) - - # Append with different shape. - element = rng.rand(rng.randint(10, 50), 42) - assert_raises(ValueError, clist.append, element) - - # Append to an empty CompactList. - clist = CompactList() - rng = np.random.RandomState(1234) - shape = (2, 3, 4) - element = rng.rand(rng.randint(10, 50), *shape) - clist.append(element) - - assert_equal(len(clist), 1) - assert_equal(clist._offsets[-1], 0) - assert_equal(clist._lengths[-1], len(element)) - assert_array_equal(clist._data, element) - assert_equal(clist.common_shape, shape) - - def test_compactlist_extend(self): - # Maybe not necessary if `self.setUp` is always called before a - # test method, anyways create a copy just in case. - clist = self.clist.copy() - - rng = np.random.RandomState(1234) - shape = self.clist.common_shape - new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(10)] - lengths = list(map(len, new_data)) - clist.extend(new_data) - assert_equal(len(clist), len(self.clist)+len(new_data)) - assert_array_equal(clist._offsets[-len(new_data):], - len(self.clist._data) + np.cumsum([0] + lengths[:-1])) - - assert_array_equal(clist._lengths[-len(new_data):], lengths) - assert_array_equal(clist._data[-sum(lengths):], - np.concatenate(new_data, axis=0)) - - # Extend with another `CompactList` object. - clist = self.clist.copy() - new_clist = CompactList(new_data) - clist.extend(new_clist) - assert_equal(len(clist), len(self.clist)+len(new_clist)) - assert_array_equal(clist._offsets[-len(new_clist):], - len(self.clist._data) + np.cumsum(np.r_[0, lengths[:-1]])) - - assert_array_equal(clist._lengths[-len(new_clist):], lengths) - assert_array_equal(clist._data[-sum(lengths):], new_clist._data) - - # Extend with another `CompactList` object that is a view (e.g. been sliced). - # Need to make sure we extend only the data we need. - clist = self.clist.copy() - new_clist = CompactList(new_data)[::2] - clist.extend(new_clist) - assert_equal(len(clist), len(self.clist)+len(new_clist)) - assert_equal(len(clist._data), len(self.clist._data)+sum(new_clist._lengths)) - assert_array_equal(clist._offsets[-len(new_clist):], - len(self.clist._data) + np.cumsum(np.r_[0, new_clist._lengths[:-1]])) - - assert_array_equal(clist._lengths[-len(new_clist):], lengths[::2]) - assert_array_equal(clist._data[-sum(new_clist._lengths):], new_clist.copy()._data) - assert_arrays_equal(clist[-len(new_clist):], new_clist) - - # Test extending an empty CompactList - clist = CompactList() - new_clist = CompactList(new_data) - clist.extend(new_clist) - assert_equal(len(clist), len(new_clist)) - assert_array_equal(clist._offsets, new_clist._offsets) - assert_array_equal(clist._lengths, new_clist._lengths) - assert_array_equal(clist._data, new_clist._data) - - def test_compactlist_getitem(self): - # Get one item - for i, e in enumerate(self.clist): - assert_array_equal(self.clist[i], e) - - # Get multiple items (this will create a view). 
- indices = list(range(len(self.clist))) - clist_view = self.clist[indices] - assert_true(clist_view is not self.clist) - assert_true(clist_view._data is self.clist._data) - assert_true(clist_view._offsets is not self.clist._offsets) - assert_true(clist_view._lengths is not self.clist._lengths) - assert_array_equal(clist_view._offsets, self.clist._offsets) - assert_array_equal(clist_view._lengths, self.clist._lengths) - assert_arrays_equal(clist_view, self.clist) - - # Get multiple items using ndarray of data type. - for dtype in [np.int8, np.int16, np.int32, np.int64]: - clist_view = self.clist[np.array(indices, dtype=dtype)] - assert_true(clist_view is not self.clist) - assert_true(clist_view._data is self.clist._data) - assert_true(clist_view._offsets is not self.clist._offsets) - assert_true(clist_view._lengths is not self.clist._lengths) - assert_array_equal(clist_view._offsets, self.clist._offsets) - assert_array_equal(clist_view._lengths, self.clist._lengths) - for e1, e2 in zip_longest(clist_view, self.clist): - assert_array_equal(e1, e2) - - # Get slice (this will create a view). - clist_view = self.clist[::2] - assert_true(clist_view is not self.clist) - assert_true(clist_view._data is self.clist._data) - assert_array_equal(clist_view._offsets, self.clist._offsets[::2]) - assert_array_equal(clist_view._lengths, self.clist._lengths[::2]) - for i, e in enumerate(clist_view): - assert_array_equal(e, self.clist[i*2]) - - # Use advance indexing with ndarray of data type bool. - idx = np.array([False, True, True, False, True]) - clist_view = self.clist[idx] - assert_true(clist_view is not self.clist) - assert_true(clist_view._data is self.clist._data) - assert_array_equal(clist_view._offsets, - self.clist._offsets[idx]) - assert_array_equal(clist_view._lengths, - self.clist._lengths[idx]) - assert_array_equal(clist_view[0], self.clist[1]) - assert_array_equal(clist_view[1], self.clist[2]) - assert_array_equal(clist_view[2], self.clist[4]) - - # Test invalid indexing - assert_raises(TypeError, self.clist.__getitem__, 'abc') - - def test_compactlist_repr(self): - # Test that calling repr on a CompactList object is not falling. - repr(self.clist) - - def test_save_and_load_compact_list(self): - - # Test saving and loading an empty CompactList. - with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: - clist = CompactList() - clist.save(f) - f.seek(0, os.SEEK_SET) - loaded_clist = CompactList.from_filename(f) - assert_array_equal(loaded_clist._data, clist._data) - assert_array_equal(loaded_clist._offsets, clist._offsets) - assert_array_equal(loaded_clist._lengths, clist._lengths) - - # Test saving and loading a CompactList. 
- with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f: - rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - clist = CompactList(data) - clist.save(f) - f.seek(0, os.SEEK_SET) - loaded_clist = CompactList.from_filename(f) - assert_array_equal(loaded_clist._data, clist._data) - assert_array_equal(loaded_clist._offsets, clist._offsets) - assert_array_equal(loaded_clist._lengths, clist._lengths) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 32f25fc04a..bf1cc1ef91 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -4,7 +4,7 @@ from nibabel.affines import apply_affine -from .compact_list import CompactList +from .array_sequence import ArraySequence class UsageWarning(Warning): @@ -109,10 +109,10 @@ def __setitem__(self, key, value): self.store[key] = value class DataPerPointDict(DataDict): - """ Internal dictionary making sure data are :class:`CompactList` objects. """ + """ Internal dictionary making sure data are :class:`ArraySequence` objects. """ def __setitem__(self, key, value): - value = CompactList(value) + value = ArraySequence(value) # We make sure we have the right amount of values (i.e. same as # the total number of points of all streamlines in the tractogram). @@ -157,7 +157,7 @@ def streamlines(self): @streamlines.setter def streamlines(self, value): - self._streamlines = CompactList(value) + self._streamlines = ArraySequence(value) @property def data_per_streamline(self): diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 259e5f4b2f..67fafa0faa 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -17,7 +17,7 @@ from nibabel.volumeutils import (native_code, swapped_code) from nibabel.orientations import (aff2axcodes, axcodes2ornt) -from .compact_list import CompactList +from .array_sequence import ArraySequence from .tractogram_file import TractogramFile from .tractogram_file import DataError, HeaderError, HeaderWarning from .tractogram import TractogramItem, Tractogram, LazyTractogram @@ -503,12 +503,12 @@ def is_correct_format(cls, fileobj): return False @classmethod - def _create_compactlist_from_generator(cls, gen): - """ Creates a CompactList object from a generator yielding tuples of + def _create_arraysequence_from_generator(cls, gen): + """ Creates a ArraySequence object from a generator yielding tuples of points, scalars and properties. 
""" - streamlines = CompactList() - scalars = CompactList() + streamlines = ArraySequence() + scalars = ArraySequence() properties = np.array([]) gen = iter(gen) @@ -703,15 +703,15 @@ def _read(): tractogram = LazyTractogram.create_from(_read) else: - streamlines, scalars, properties = cls._create_compactlist_from_generator(trk_reader) + streamlines, scalars, properties = cls._create_arraysequence_from_generator(trk_reader) tractogram = Tractogram(streamlines) for scalar_name, slice_ in data_per_point_slice.items(): - clist = CompactList() - clist._data = scalars._data[:, slice_] - clist._offsets = scalars._offsets - clist._lengths = scalars._lengths - tractogram.data_per_point[scalar_name] = clist + seq = ArraySequence() + seq._data = scalars._data[:, slice_] + seq._offsets = scalars._offsets + seq._lengths = scalars._lengths + tractogram.data_per_point[scalar_name] = seq for property_name, slice_ in data_per_streamline_slice.items(): tractogram.data_per_streamline[property_name] = properties[:, slice_] From 2b95c2973ed5e69c7c0e22779ee54db8b9291c46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 10 Feb 2016 14:24:46 -0500 Subject: [PATCH 078/135] Added test for creating ArraySequence from arrays of different number of dimensions. --- .../streamlines/tests/test_array_sequence.py | 34 ++++++++++++------- 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index 57efcfae44..02a2b05891 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -29,18 +29,6 @@ def test_creating_empty_arraysequence(self): def test_creating_arraysequence_from_list(self): rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = list(map(len, data)) - - seq = ArraySequence(data) - assert_equal(len(seq), len(data)) - assert_equal(len(seq._offsets), len(data)) - assert_equal(len(seq._lengths), len(data)) - assert_equal(seq._data.shape[0], sum(lengths)) - assert_equal(seq._data.shape[1], 3) - assert_array_equal(seq._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) - assert_array_equal(seq._lengths, lengths) - assert_equal(seq.common_shape, data[0].shape[1:]) # Empty list seq = ArraySequence([]) @@ -50,7 +38,29 @@ def test_creating_arraysequence_from_list(self): assert_equal(seq._data.ndim, 0) assert_true(seq.common_shape == ()) + # List of ndarrays. + N = 5 + nb_arrays = 10 + for ndim in range(0, N+1): + common_shape = tuple([rng.randint(1, 10) for _ in range(ndim-1)]) + data = [rng.rand(*(rng.randint(10, 50),) + common_shape) + for _ in range(nb_arrays)] + lengths = list(map(len, data)) + + seq = ArraySequence(data) + assert_equal(len(seq), len(data)) + assert_equal(len(seq), nb_arrays) + assert_equal(len(seq._offsets), nb_arrays) + assert_equal(len(seq._lengths), nb_arrays) + assert_equal(seq._data.shape[0], sum(lengths)) + assert_equal(seq._data.shape[1:], common_shape) + assert_equal(seq.common_shape, common_shape) + assert_array_equal(seq._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) + assert_array_equal(seq._lengths, lengths) + # Force ArraySequence constructor to use buffering. 
+ data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] + lengths = list(map(len, data)) old_buffer_size = ArraySequence.BUFFER_SIZE ArraySequence.BUFFER_SIZE = 1 seq = ArraySequence(data) From c78d18b22e208d6e7ffeee10895a5a63d4851551 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 10 Feb 2016 15:57:11 -0500 Subject: [PATCH 079/135] Added validation in nib.streamlines.save --- nibabel/streamlines/__init__.py | 32 ++++++++++++------- nibabel/streamlines/tests/test_streamlines.py | 21 +++++++++++- nibabel/streamlines/tractogram_file.py | 4 +++ 3 files changed, 44 insertions(+), 13 deletions(-) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 6a1d8a87cd..4b94c0d013 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -1,20 +1,16 @@ import os +import warnings from ..externals.six import string_types from .header import Field from .array_sequence import ArraySequence from .tractogram import Tractogram, LazyTractogram -from .tractogram_file import TractogramFile +from .tractogram_file import ExtensionWarning from .trk import TrkFile -#from .tck import TckFile -#from .vtk import VtkFile # List of all supported formats -FORMATS = {".trk": TrkFile, - #".tck": TckFile, - #".vtk": VtkFile, - } +FORMATS = {".trk": TrkFile} def is_supported(fileobj): @@ -112,18 +108,30 @@ def save(tractogram, filename, **kwargs): filename : str Name of the file where the tractogram will be saved. \*\*kwargs : keyword arguments - Keyword arguments passed to :class:`TractogramFile` constructor. + Keyword arguments passed to :class:`TractogramFile` constructor. + Should not be specified if `tractogram` is already an instance of + :class:`TractogramFile`. """ - tractogram_file = tractogram + tractogram_file_class = detect_format(filename) if isinstance(tractogram, Tractogram): - # We have to guess the file format. - tractogram_file_class = detect_format(filename) - if tractogram_file_class is None: msg = "Unknown tractogram file format: '{}'".format(filename) raise ValueError(msg) tractogram_file = tractogram_file_class(tractogram, **kwargs) + else: # Assume it's a TractogramFile object. + tractogram_file = tractogram + if tractogram_file_class is None \ + or not isinstance(tractogram_file, tractogram_file_class): + msg = ("The extension you specified is unusual for the provided" + " 'TractogramFile' object.") + warnings.warn(msg, ExtensionWarning) + + if len(kwargs) > 0: + msg = ("A 'TractogramFile' object was provided, no need for" + " keyword arguments.") + raise ValueError(msg) + tractogram_file.save(filename) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 84b7b5c4b3..92893792ea 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -14,7 +14,7 @@ from .test_tractogram import assert_tractogram_equal from ..tractogram import Tractogram, LazyTractogram -from ..tractogram_file import TractogramFile +from ..tractogram_file import TractogramFile, ExtensionWarning from ..tractogram import UsageWarning from .. import trk @@ -203,6 +203,25 @@ def test_load_complex_file(self): assert_tractogram_equal(tfile.tractogram, tractogram) + def test_save_tractogram_file(self): + tractogram = Tractogram(self.streamlines) + trk_file = trk.TrkFile(tractogram) + + # No need for keyword arguments. + assert_raises(ValueError, nib.streamlines.save, + trk_file, "dummy.trk", header={}) + + # Wrong extension. 
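
The mismatch below pairs a `TrkFile` instance with a '.tck' name, so `save` issues an `ExtensionWarning`, and it raises `ValueError` because keyword arguments are not accepted together with a ready-made `TractogramFile`. By contrast, the accepted call shapes look roughly like this (a sketch assuming an empty tractogram and the default TRK header are fine, as the surrounding tests suggest):

    import nibabel as nib
    from nibabel.streamlines import Tractogram
    from nibabel.streamlines.trk import TrkFile

    tractogram = Tractogram()
    nib.streamlines.save(tractogram, 'out.trk')           # class picked from extension
    nib.streamlines.save(TrkFile(tractogram), 'out.trk')  # no kwargs allowed here
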
+ with clear_and_catch_warnings(record=True, + modules=[nib.streamlines]) as w: + trk_file = trk.TrkFile(tractogram) + assert_raises(ValueError, nib.streamlines.save, + trk_file, "dummy.tck", header={}) + + assert_equal(len(w), 1) + assert_true(issubclass(w[0].category, ExtensionWarning)) + assert_true("extension" in str(w[0].message)) + def test_save_empty_file(self): tractogram = Tractogram() for ext, cls in nib.streamlines.FORMATS.items(): diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index cdd0463d5d..c99d35bccc 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -4,6 +4,10 @@ from .header import Field +class ExtensionWarning(Warning): + pass + + class HeaderWarning(Warning): pass From d1f4dbeff349b8c76d30b0dc291bd08303b30ca7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Thu, 11 Feb 2016 01:38:27 -0500 Subject: [PATCH 080/135] Fixed typos + clean code --- nibabel/streamlines/__init__.py | 13 +-- nibabel/streamlines/array_sequence.py | 106 +++++++++--------- .../streamlines/tests/test_array_sequence.py | 2 +- nibabel/streamlines/tractogram.py | 11 +- nibabel/streamlines/tractogram_file.py | 3 - nibabel/streamlines/trk.py | 5 - nibabel/streamlines/utils.py | 1 - 7 files changed, 61 insertions(+), 80 deletions(-) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 4b94c0d013..8119865c4a 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -26,7 +26,6 @@ def is_supported(fileobj): Returns ------- is_supported : boolean - """ return detect_format(fileobj) is not None @@ -45,13 +44,11 @@ def detect_format(fileobj): ------- tractogram_file : :class:`TractogramFile` class The class type guessed from the content of `fileobj`. - """ for format in FORMATS.values(): try: if format.is_correct_format(fileobj): return format - except IOError: pass @@ -77,14 +74,13 @@ def load(fileobj, lazy_load=False, ref=None): Returns ------- - tractogram_file : :class:TractogramFile object - Returns an instance of a :class:TractogramFile containing data and + tractogram_file : :class:`TractogramFile` object + Returns an instance of a :class:`TractogramFile` containing data and metadata of the tractogram loaded from `fileobj`. Notes ----- The streamline coordinate (0,0,0) refers to the center of the voxel. - """ tractogram_file = detect_format(fileobj) @@ -111,7 +107,6 @@ def save(tractogram, filename, **kwargs): Keyword arguments passed to :class:`TractogramFile` constructor. Should not be specified if `tractogram` is already an instance of :class:`TractogramFile`. - """ tractogram_file_class = detect_format(filename) if isinstance(tractogram, Tractogram): @@ -123,8 +118,8 @@ def save(tractogram, filename, **kwargs): else: # Assume it's a TractogramFile object. 
tractogram_file = tractogram - if tractogram_file_class is None \ - or not isinstance(tractogram_file, tractogram_file_class): + if (tractogram_file_class is None or + not isinstance(tractogram_file, tractogram_file_class)): msg = ("The extension you specified is unusual for the provided" " 'TractogramFile' object.") warnings.warn(msg, ExtensionWarning) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index 2f3a82739d..94078e6b71 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -1,6 +1,14 @@ import numpy as np +def is_array_sequence(obj): + """ Return True if `obj` is an array sequence. """ + try: + return obj.is_array_sequence + except AttributeError: + return False + + class ArraySequence(object): """ Sequence of ndarrays having variable first dimension sizes. @@ -9,10 +17,9 @@ class ArraySequence(object): remaining dimensions. More generally, an instance of :class:`ArraySequence` of length $N$ is - compoosed of $N$ ndarrays of shape $(d_n, d_2, ... d_D)$ where - $n \in [1,N]$, $d_n$ might vary from one ndarray to another and - $d_2, ..., d_D)$ have to be the same for every ndarray. - + composed of $N$ ndarrays of shape $(d_1, d_2, ... d_D)$ where $d_1$ + can vary in length between arrays but $(d_2, ..., d_D)$ have to be the + same for every ndarray. """ BUFFER_SIZE = 87382*4 # About 4 Mb if item shape is 3 (e.g. 3D points). @@ -27,48 +34,55 @@ def __init__(self, iterable=None): from array-like objects yielded by the iterable. If :class:`ArraySequence`, create a view (no memory is allocated). For an actual copy use :meth:`.copy` instead. - """ # Create new empty `ArraySequence` object. self._data = np.array(0) self._offsets = np.array([], dtype=np.intp) self._lengths = np.array([], dtype=np.intp) - if isinstance(iterable, ArraySequence): + if iterable is None: + return + + if is_array_sequence(iterable): # Create a view. self._data = iterable._data self._offsets = iterable._offsets self._lengths = iterable._lengths + return - elif iterable is not None: - offsets = [] - lengths = [] - # Initialize the `ArraySequence` object from iterable's item. - offset = 0 - for i, e in enumerate(iterable): - e = np.asarray(e) - if i == 0: - new_shape = (ArraySequence.BUFFER_SIZE,) + e.shape[1:] - self._data = np.empty(new_shape, dtype=e.dtype) - - end = offset + len(e) - if end >= len(self._data): - # Resize needed, adding `len(e)` items plus some buffer. - nb_points = len(self._data) - nb_points += len(e) + ArraySequence.BUFFER_SIZE - self._data.resize((nb_points,) + self.common_shape) - - offsets.append(offset) - lengths.append(len(e)) - self._data[offset:offset+len(e)] = e - offset += len(e) - - self._offsets = np.asarray(offsets) - self._lengths = np.asarray(lengths) - - # Clear unused memory. - if self._data.ndim != 0: - self._data.resize((offset,) + self.common_shape) + # Add elements of the iterable. + offsets = [] + lengths = [] + # Initialize the `ArraySequence` object from iterable's item. + offset = 0 + for i, e in enumerate(iterable): + e = np.asarray(e) + if i == 0: + new_shape = (ArraySequence.BUFFER_SIZE,) + e.shape[1:] + self._data = np.empty(new_shape, dtype=e.dtype) + + end = offset + len(e) + if end >= len(self._data): + # Resize needed, adding `len(e)` items plus some buffer. 
+ nb_points = len(self._data) + nb_points += len(e) + ArraySequence.BUFFER_SIZE + self._data.resize((nb_points,) + self.common_shape) + + offsets.append(offset) + lengths.append(len(e)) + self._data[offset:offset+len(e)] = e + offset += len(e) + + self._offsets = np.asarray(offsets) + self._lengths = np.asarray(lengths) + + # Clear unused memory. + if self._data.ndim != 0: + self._data.resize((offset,) + self.common_shape) + + @property + def is_array_sequence(self): + return True @property def common_shape(self): @@ -121,7 +135,6 @@ def extend(self, elements): ----- The shape of the elements to be added must match the one of the data of this :class:`ArraySequence` except for the first dimension. - """ if self._data.ndim == 0: elem = np.asarray(elements[0]) @@ -129,7 +142,7 @@ def extend(self, elements): next_offset = self._data.shape[0] - if isinstance(elements, ArraySequence): + if is_array_sequence(elements): self._data.resize((self._data.shape[0]+sum(elements._lengths), self._data.shape[1])) @@ -191,36 +204,27 @@ def __getitem__(self, idx): If `idx` is an int, returns the selected sequence. Otherwise, returns a :class:`ArraySequence` object which is view of the selected sequences. - """ - if isinstance(idx, int) or isinstance(idx, np.integer): + if isinstance(idx, (int, np.integer)): start = self._offsets[idx] return self._data[start:start+self._lengths[idx]] - elif isinstance(idx, slice) or isinstance(idx, list): + elif isinstance(idx, (slice, list)): clist = ArraySequence() clist._data = self._data clist._offsets = self._offsets[idx] clist._lengths = self._lengths[idx] return clist - elif isinstance(idx, np.ndarray) and np.issubdtype(idx.dtype, - np.integer): + elif (isinstance(idx, np.ndarray) and + (np.issubdtype(idx.dtype, np.integer) or + np.issubdtype(idx.dtype, np.bool))): clist = ArraySequence() clist._data = self._data clist._offsets = self._offsets[idx] clist._lengths = self._lengths[idx] return clist - elif isinstance(idx, np.ndarray) and idx.dtype == np.bool: - clist = ArraySequence() - clist._data = self._data - clist._offsets = [self._offsets[i] - for i, take_it in enumerate(idx) if take_it] - clist._lengths = [self._lengths[i] - for i, take_it in enumerate(idx) if take_it] - return clist - raise TypeError("Index must be either an int, a slice, a list of int" " or a ndarray of bool! Not " + str(type(idx))) diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index 02a2b05891..ea62b7acaa 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -265,7 +265,7 @@ def test_arraysequence_getitem(self): for i, e in enumerate(seq_view): assert_array_equal(e, self.seq[i*2]) - # Use advance indexing with ndarray of data type bool. + # Use advanced indexing with ndarray of data type bool. idx = np.array([False, True, True, False, True]) seq_view = self.seq[idx] assert_true(seq_view is not self.seq) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index bf1cc1ef91..f095c32c9a 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -31,7 +31,6 @@ class TractogramItem(object): particular streamline. Each key `k` is mapped to a ndarray of shape (Nt, Mk), where `Nt` is the number of points of this streamline and `Mk` is the dimension of the data associated with key `k`. 
- """ def __init__(self, streamline, data_for_streamline, data_for_points): self.streamline = np.asarray(streamline) @@ -56,7 +55,6 @@ class Tractogram(object): affine_to_rasmm : ndarray shape (4, 4) Affine that brings the streamlines back to *RAS+* and *mm* space where coordinate (0,0,0) refers to the center of the voxel. - """ class DataDict(collections.MutableMapping): def __init__(self, tractogram, *args, **kwargs): @@ -144,7 +142,6 @@ def __init__(self, streamlines=None, points for a particular streamline t and M is the number of scalars associated to each point (excluding the three coordinates). - """ self.streamlines = streamlines self.data_per_streamline = data_per_streamline @@ -196,7 +193,7 @@ def __getitem__(self, idx): for key in self.data_per_point: data_per_point[key] = self.data_per_point[key][idx] - if isinstance(idx, int) or isinstance(idx, np.integer): + if isinstance(idx, (int, np.integer)): return TractogramItem(pts, data_per_streamline, data_per_point) return Tractogram(pts, data_per_streamline, data_per_point) @@ -243,7 +240,6 @@ def apply_affine(self, affine, lazy=False): it returns a :class:`LazyTractogram` object, otherwise it returns a reference to this :class:`Tractogram` object with updated streamlines. - """ if lazy: lazy_tractogram = LazyTractogram.from_tractogram(self) @@ -277,7 +273,6 @@ class LazyTractogram(Tractogram): ----- If provided, `scalars` and `properties` must yield the same number of values as `streamlines`. - """ class LazyDict(collections.MutableMapping): @@ -329,7 +324,6 @@ def __init__(self, streamlines=None, is the number of points of that streamline t and M is the number of scalars associated to each point (excluding the three coordinates). - """ super(LazyTractogram, self).__init__(streamlines, data_per_streamline, @@ -351,7 +345,6 @@ def from_tractogram(cls, tractogram): ------- lazy_tractogram : :class:`LazyTractogram` object New lazy tractogram. - """ data_per_streamline = {} for key, value in tractogram.data_per_streamline.items(): @@ -385,7 +378,6 @@ def create_from(cls, data_func): ------- lazy_tractogram : :class:`LazyTractogram` object New lazy tractogram. - """ if not callable(data_func): raise TypeError("`data_func` must be a coroutine.") @@ -541,7 +533,6 @@ def apply_affine(self, affine): ------- lazy_tractogram : :class:`LazyTractogram` object Reference to this instance of :class:`LazyTractogram`. - """ # Update the affine that will be applied when returning streamlines. self._affine_to_apply = np.dot(affine, self._affine_to_apply) diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index c99d35bccc..9a271df4ee 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -95,7 +95,6 @@ def is_correct_format(cls, fileobj): is_correct_format : {True, False} Returns True if `fileobj` is in the right streamlines file format, otherwise returns False. - """ raise NotImplementedError() @@ -118,7 +117,6 @@ def load(cls, fileobj, lazy_load=True): tractogram_file : :class:`TractogramFile` object Returns an object containing tractogram data and header information. - """ raise NotImplementedError() @@ -131,6 +129,5 @@ def save(self, fileobj): fileobj : string or file-like object If string, a filename; otherwise an open file-like object opened and ready to write. 
- """ raise NotImplementedError() diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 67fafa0faa..b9ea297b53 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -105,7 +105,6 @@ class TrkReader(object): Thus, streamlines are shifted of half a voxel on load and are shifted back on save. - """ def __init__(self, fileobj): self.fileobj = fileobj @@ -433,7 +432,6 @@ class TrkFile(TractogramFile): Thus, streamlines are shifted of half a voxel on load and are shifted back on save. - """ # Contants @@ -613,7 +611,6 @@ def load(cls, fileobj, lazy_load=False): Streamlines of the returned tractogram are assumed to be in *RAS* and *mm* space where coordinate (0,0,0) refers to the center of the voxel. - """ trk_reader = TrkReader(fileobj) @@ -730,7 +727,6 @@ def save(self, fileobj): If string, a filename; otherwise an open file-like object pointing to TRK file (and ready to read from the beginning of the TRK header data). - """ trk_writer = TrkWriter(fileobj, self.header) trk_writer.write(self.tractogram) @@ -749,7 +745,6 @@ def __str__(self): ------- info : string Header information relevant to the TRK format. - """ hdr = self.header diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index ce46c8f3eb..254e9b3442 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -20,7 +20,6 @@ def get_affine_from_reference(ref): ------- affine : ndarray (4, 4) Transformation matrix mapping voxel space to RAS+mm space. - """ if type(ref) is np.ndarray: if ref.shape != (4, 4): From 66fea1913e45220b5d3254e86f83ac97138f5918 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Thu, 11 Feb 2016 16:02:22 -0500 Subject: [PATCH 081/135] Refactored test_array_sequence following @matthew-brett's suggestions. --- nibabel/streamlines/array_sequence.py | 51 ++- .../streamlines/tests/test_array_sequence.py | 381 ++++++++---------- 2 files changed, 192 insertions(+), 240 deletions(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index 94078e6b71..32521eb20c 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -36,6 +36,7 @@ def __init__(self, iterable=None): For an actual copy use :meth:`.copy` instead. """ # Create new empty `ArraySequence` object. + self._is_view = False self._data = np.array(0) self._offsets = np.array([], dtype=np.intp) self._lengths = np.array([], dtype=np.intp) @@ -48,6 +49,7 @@ def __init__(self, iterable=None): self._data = iterable._data self._offsets = iterable._offsets self._lengths = iterable._lengths + self._is_view = True return # Add elements of the iterable. @@ -136,6 +138,9 @@ def extend(self, elements): The shape of the elements to be added must match the one of the data of this :class:`ArraySequence` except for the first dimension. """ + if len(elements) == 0: + return + if self._data.ndim == 0: elem = np.asarray(elements[0]) self._data = np.zeros((0, elem.shape[1]), dtype=elem.dtype) @@ -168,23 +173,23 @@ def copy(self): # We do not simply deepcopy this object since we might have a chance # to use less memory. For example, if the array sequence being copied # is the result of a slicing operation on a array sequence. 
- clist = ArraySequence() + seq = ArraySequence() total_lengths = np.sum(self._lengths) - clist._data = np.empty((total_lengths,) + self._data.shape[1:], - dtype=self._data.dtype) + seq._data = np.empty((total_lengths,) + self._data.shape[1:], + dtype=self._data.dtype) next_offset = 0 offsets = [] for offset, length in zip(self._offsets, self._lengths): offsets.append(next_offset) chunk = self._data[offset:offset+length] - clist._data[next_offset:next_offset+length] = chunk + seq._data[next_offset:next_offset+length] = chunk next_offset += length - clist._offsets = np.asarray(offsets) - clist._lengths = self._lengths.copy() + seq._offsets = np.asarray(offsets) + seq._lengths = self._lengths.copy() - return clist + return seq def __getitem__(self, idx): """ Gets sequence(s) through advanced indexing. @@ -210,20 +215,22 @@ def __getitem__(self, idx): return self._data[start:start+self._lengths[idx]] elif isinstance(idx, (slice, list)): - clist = ArraySequence() - clist._data = self._data - clist._offsets = self._offsets[idx] - clist._lengths = self._lengths[idx] - return clist + seq = ArraySequence() + seq._data = self._data + seq._offsets = self._offsets[idx] + seq._lengths = self._lengths[idx] + seq._is_view = True + return seq elif (isinstance(idx, np.ndarray) and (np.issubdtype(idx.dtype, np.integer) or np.issubdtype(idx.dtype, np.bool))): - clist = ArraySequence() - clist._data = self._data - clist._offsets = self._offsets[idx] - clist._lengths = self._lengths[idx] - return clist + seq = ArraySequence() + seq._data = self._data + seq._offsets = self._offsets[idx] + seq._lengths = self._lengths[idx] + seq._is_view = True + return seq raise TypeError("Index must be either an int, a slice, a list of int" " or a ndarray of bool! Not " + str(type(idx))) @@ -253,8 +260,8 @@ def save(self, filename): def from_filename(cls, filename): """ Loads a :class:`ArraySequence` object from a .npz file. 
""" content = np.load(filename) - clist = cls() - clist._data = content["data"] - clist._offsets = content["offsets"] - clist._lengths = content["lengths"] - return clist + seq = cls() + seq._data = content["data"] + seq._offsets = content["offsets"] + seq._lengths = content["lengths"] + return seq diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index ea62b7acaa..fe8e9da9da 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -8,285 +8,232 @@ from numpy.testing import assert_array_equal from nibabel.externals.six.moves import zip, zip_longest -from ..array_sequence import ArraySequence +from ..array_sequence import ArraySequence, is_array_sequence -class TestArraySequence(unittest.TestCase): +SEQ_DATA = {} + + +def setup(): + global SEQ_DATA + rng = np.random.RandomState(42) + SEQ_DATA['rng'] = rng + SEQ_DATA['data'] = generate_data(nb_arrays=10, common_shape=(3,), rng=rng) + SEQ_DATA['seq'] = ArraySequence(SEQ_DATA['data']) + + +def generate_data(nb_arrays, common_shape, rng): + data = [rng.rand(*(rng.randint(10, 50),) + common_shape) + for _ in range(nb_arrays)] + return data + + +def check_empty_arr_seq(seq): + assert_equal(len(seq), 0) + assert_equal(len(seq._offsets), 0) + assert_equal(len(seq._lengths), 0) + assert_equal(seq._data.ndim, 0) + assert_true(seq.common_shape == ()) + + +def check_arr_seq(seq, arrays): + lengths = list(map(len, arrays)) + assert_true(is_array_sequence(seq)) + assert_equal(len(seq), len(arrays)) + assert_equal(len(seq._offsets), len(arrays)) + assert_equal(len(seq._lengths), len(arrays)) + assert_equal(seq._data.shape[1:], arrays[0].shape[1:]) + assert_equal(seq.common_shape, arrays[0].shape[1:]) + assert_arrays_equal(seq, arrays) + + # If seq is a view, there order of internal data is not guarantied. + if seq._is_view: + # The only thing we can check is the _lengths. + assert_array_equal(sorted(seq._lengths), sorted(lengths)) + else: + assert_equal(seq._data.shape[0], sum(lengths)) + assert_array_equal(seq._data, np.concatenate(arrays, axis=0)) + assert_array_equal(seq._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) + assert_array_equal(seq._lengths, lengths) + + +def check_arr_seq_view(seq_view, seq): + assert_true(seq_view._is_view) + assert_true(seq_view is not seq) + assert_true(seq_view._data is seq._data) + assert_true(seq_view._offsets is not seq._offsets) + assert_true(seq_view._lengths is not seq._lengths) - def setUp(self): - rng = np.random.RandomState(42) - self.data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - self.lengths = list(map(len, self.data)) - self.seq = ArraySequence(self.data) + +class TestArraySequence(unittest.TestCase): def test_creating_empty_arraysequence(self): - seq = ArraySequence() - assert_equal(len(seq), 0) - assert_equal(len(seq._offsets), 0) - assert_equal(len(seq._lengths), 0) - assert_equal(seq._data.ndim, 0) - assert_true(seq.common_shape == ()) + check_empty_arr_seq(ArraySequence()) def test_creating_arraysequence_from_list(self): - rng = np.random.RandomState(42) - # Empty list - seq = ArraySequence([]) - assert_equal(len(seq), 0) - assert_equal(len(seq._offsets), 0) - assert_equal(len(seq._lengths), 0) - assert_equal(seq._data.ndim, 0) - assert_true(seq.common_shape == ()) + check_empty_arr_seq(ArraySequence([])) # List of ndarrays. 
N = 5 - nb_arrays = 10 for ndim in range(0, N+1): - common_shape = tuple([rng.randint(1, 10) for _ in range(ndim-1)]) - data = [rng.rand(*(rng.randint(10, 50),) + common_shape) - for _ in range(nb_arrays)] - lengths = list(map(len, data)) - - seq = ArraySequence(data) - assert_equal(len(seq), len(data)) - assert_equal(len(seq), nb_arrays) - assert_equal(len(seq._offsets), nb_arrays) - assert_equal(len(seq._lengths), nb_arrays) - assert_equal(seq._data.shape[0], sum(lengths)) - assert_equal(seq._data.shape[1:], common_shape) - assert_equal(seq.common_shape, common_shape) - assert_array_equal(seq._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) - assert_array_equal(seq._lengths, lengths) + common_shape = tuple([SEQ_DATA['rng'].randint(1, 10) + for _ in range(ndim-1)]) + data = generate_data(nb_arrays=10, common_shape=common_shape, + rng=SEQ_DATA['rng']) + check_arr_seq(ArraySequence(data), data) # Force ArraySequence constructor to use buffering. - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = list(map(len, data)) old_buffer_size = ArraySequence.BUFFER_SIZE ArraySequence.BUFFER_SIZE = 1 - seq = ArraySequence(data) - assert_equal(len(seq), len(data)) - assert_equal(len(seq._offsets), len(data)) - assert_equal(len(seq._lengths), len(data)) - assert_equal(seq._data.shape[0], sum(lengths)) - assert_equal(seq._data.shape[1], 3) - assert_array_equal(seq._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) - assert_array_equal(seq._lengths, lengths) - assert_equal(seq.common_shape, data[0].shape[1:]) + check_arr_seq(ArraySequence(SEQ_DATA['data']), SEQ_DATA['data']) ArraySequence.BUFFER_SIZE = old_buffer_size def test_creating_arraysequence_from_generator(self): - rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = list(map(len, data)) - - gen = (e for e in data) - seq = ArraySequence(gen) - assert_equal(len(seq), len(data)) - assert_equal(len(seq._offsets), len(data)) - assert_equal(len(seq._lengths), len(data)) - assert_equal(seq._data.shape[0], sum(lengths)) - assert_equal(seq._data.shape[1], 3) - assert_array_equal(seq._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) - assert_array_equal(seq._lengths, lengths) - assert_equal(seq.common_shape, data[0].shape[1:]) + gen = (e for e in SEQ_DATA['data']) + check_arr_seq(ArraySequence(gen), SEQ_DATA['data']) # Already consumed generator - seq = ArraySequence(gen) - assert_equal(len(seq), 0) - assert_equal(len(seq._offsets), 0) - assert_equal(len(seq._lengths), 0) - assert_equal(seq._data.ndim, 0) - assert_true(seq.common_shape == ()) + check_empty_arr_seq(ArraySequence(gen)) def test_creating_arraysequence_from_arraysequence(self): - rng = np.random.RandomState(42) - data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)] - lengths = list(map(len, data)) - - seq = ArraySequence(data) - seq2 = ArraySequence(seq) - assert_equal(len(seq2), len(data)) - assert_equal(len(seq2._offsets), len(data)) - assert_equal(len(seq2._lengths), len(data)) - assert_equal(seq2._data.shape[0], sum(lengths)) - assert_equal(seq2._data.shape[1], 3) - assert_array_equal(seq2._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) - assert_array_equal(seq2._lengths, lengths) - assert_equal(seq2.common_shape, data[0].shape[1:]) + seq = ArraySequence(SEQ_DATA['data']) + check_arr_seq(ArraySequence(seq), SEQ_DATA['data']) + + # From an empty ArraySequence + seq = ArraySequence() + check_empty_arr_seq(ArraySequence(seq)) def test_arraysequence_iter(self): - for e, d in zip(self.seq, self.data): - assert_array_equal(e, d) 
+ assert_arrays_equal(SEQ_DATA['seq'], SEQ_DATA['data']) - # Try iterate through a corrupted ArraySequence object. - seq = self.seq.copy() + # Try iterating through a corrupted ArraySequence object. + seq = SEQ_DATA['seq'].copy() seq._lengths = seq._lengths[::2] assert_raises(ValueError, list, seq) def test_arraysequence_copy(self): - seq = self.seq.copy() - assert_array_equal(seq._data, self.seq._data) - assert_true(seq._data is not self.seq._data) - assert_array_equal(seq._offsets, self.seq._offsets) - assert_true(seq._offsets is not self.seq._offsets) - assert_array_equal(seq._lengths, self.seq._lengths) - assert_true(seq._lengths is not self.seq._lengths) - - assert_equal(seq.common_shape, self.seq.common_shape) - - # When taking a copy of a `ArraySequence` generated by slicing. - # Only needed data should be kept. - seq = self.seq[::2].copy() - - assert_true(seq._data.shape[0] < self.seq._data.shape[0]) - assert_true(len(seq) < len(self.seq)) - assert_true(seq._data is not self.seq._data) - assert_array_equal(seq._lengths, self.seq[::2]._lengths) - assert_array_equal(seq._offsets, - np.cumsum(np.r_[0, self.seq[::2]._lengths])[:-1]) - assert_arrays_equal(seq, self.seq[::2]) + seq = SEQ_DATA['seq'].copy() + assert_array_equal(seq._data, SEQ_DATA['seq']._data) + assert_true(seq._data is not SEQ_DATA['seq']._data) + assert_array_equal(seq._offsets, SEQ_DATA['seq']._offsets) + assert_true(seq._offsets is not SEQ_DATA['seq']._offsets) + assert_array_equal(seq._lengths, SEQ_DATA['seq']._lengths) + assert_true(seq._lengths is not SEQ_DATA['seq']._lengths) + assert_equal(seq.common_shape, SEQ_DATA['seq'].common_shape) + + # Taking a copy of an `ArraySequence` generated by slicing. + # Only keep needed data. + seq = SEQ_DATA['seq'][::2].copy() + check_arr_seq(seq, SEQ_DATA['data'][::2]) + assert_true(seq._data is not SEQ_DATA['seq']._data) def test_arraysequence_append(self): - # Maybe not necessary if `self.setUp` is always called before a - # test method, anyways create a copy just in case. - seq = self.seq.copy() + element = generate_data(nb_arrays=1, + common_shape=SEQ_DATA['seq'].common_shape, + rng=SEQ_DATA['rng'])[0] - rng = np.random.RandomState(1234) - element = rng.rand(rng.randint(10, 50), *self.seq.common_shape) + # Append a new element. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. seq.append(element) - assert_equal(len(seq), len(self.seq)+1) - assert_equal(seq._offsets[-1], len(self.seq._data)) - assert_equal(seq._lengths[-1], len(element)) - assert_array_equal(seq._data[-len(element):], element) - - # Append with different shape. - element = rng.rand(rng.randint(10, 50), 42) - assert_raises(ValueError, seq.append, element) + check_arr_seq(seq, SEQ_DATA['data'] + [element]) # Append to an empty ArraySequence. seq = ArraySequence() - rng = np.random.RandomState(1234) - shape = (2, 3, 4) - element = rng.rand(rng.randint(10, 50), *shape) seq.append(element) + check_arr_seq(seq, [element]) - assert_equal(len(seq), 1) - assert_equal(seq._offsets[-1], 0) - assert_equal(seq._lengths[-1], len(element)) - assert_array_equal(seq._data, element) - assert_equal(seq.common_shape, shape) + # Append an element with different shape. + element = generate_data(nb_arrays=1, + common_shape=SEQ_DATA['seq'].common_shape*2, + rng=SEQ_DATA['rng'])[0] + assert_raises(ValueError, seq.append, element) def test_arraysequence_extend(self): - # Maybe not necessary if `self.setUp` is always called before a - # test method, anyways create a copy just in case. 
- seq = self.seq.copy() - - rng = np.random.RandomState(1234) - shape = self.seq.common_shape - new_data = [rng.rand(rng.randint(10, 50), *shape) for _ in range(10)] - lengths = list(map(len, new_data)) - seq.extend(new_data) - assert_equal(len(seq), len(self.seq)+len(new_data)) - assert_array_equal(seq._offsets[-len(new_data):], - len(self.seq._data) + np.cumsum([0] + lengths[:-1])) + new_data = generate_data(nb_arrays=10, + common_shape=SEQ_DATA['seq'].common_shape, + rng=SEQ_DATA['rng']) - assert_array_equal(seq._lengths[-len(new_data):], lengths) - assert_array_equal(seq._data[-sum(lengths):], - np.concatenate(new_data, axis=0)) + # Extend with an empty list. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + seq.extend([]) + check_arr_seq(seq, SEQ_DATA['data']) - # Extend with another `ArraySequence` object. - seq = self.seq.copy() - new_seq = ArraySequence(new_data) - seq.extend(new_seq) - assert_equal(len(seq), len(self.seq)+len(new_seq)) - assert_array_equal(seq._offsets[-len(new_seq):], - len(self.seq._data) + np.cumsum(np.r_[0, lengths[:-1]])) + # Extend with a list of ndarrays. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + seq.extend(new_data) + check_arr_seq(seq, SEQ_DATA['data'] + new_data) - assert_array_equal(seq._lengths[-len(new_seq):], lengths) - assert_array_equal(seq._data[-sum(lengths):], new_seq._data) + # Extend with another `ArraySequence` object. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + seq.extend(ArraySequence(new_data)) + check_arr_seq(seq, SEQ_DATA['data'] + new_data) - # Extend with another `ArraySequence` object that is a view (e.g. been sliced). + # Extend with an `ArraySequence` view (e.g. been sliced). # Need to make sure we extend only the data we need. - seq = self.seq.copy() - new_seq = ArraySequence(new_data)[::2] - seq.extend(new_seq) - assert_equal(len(seq), len(self.seq)+len(new_seq)) - assert_equal(len(seq._data), len(self.seq._data)+sum(new_seq._lengths)) - assert_array_equal(seq._offsets[-len(new_seq):], - len(self.seq._data) + np.cumsum(np.r_[0, new_seq._lengths[:-1]])) - - assert_array_equal(seq._lengths[-len(new_seq):], lengths[::2]) - assert_array_equal(seq._data[-sum(new_seq._lengths):], new_seq.copy()._data) - assert_arrays_equal(seq[-len(new_seq):], new_seq) + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + seq.extend(ArraySequence(new_data)[::2]) + check_arr_seq(seq, SEQ_DATA['data'] + new_data[::2]) # Test extending an empty ArraySequence seq = ArraySequence() - new_seq = ArraySequence(new_data) - seq.extend(new_seq) - assert_equal(len(seq), len(new_seq)) - assert_array_equal(seq._offsets, new_seq._offsets) - assert_array_equal(seq._lengths, new_seq._lengths) - assert_array_equal(seq._data, new_seq._data) + seq.extend(ArraySequence()) + check_empty_arr_seq(seq) + + seq.extend(SEQ_DATA['seq']) + check_arr_seq(seq, SEQ_DATA['data']) + + # Extend with elements of different shape. + data = generate_data(nb_arrays=10, + common_shape=SEQ_DATA['seq'].common_shape*2, + rng=SEQ_DATA['rng']) + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + assert_raises(ValueError, seq.extend, data) def test_arraysequence_getitem(self): # Get one item - for i, e in enumerate(self.seq): - assert_array_equal(self.seq[i], e) - - # Get multiple items (this will create a view). 
- indices = list(range(len(self.seq)))
- seq_view = self.seq[indices]
- assert_true(seq_view is not self.seq)
- assert_true(seq_view._data is self.seq._data)
- assert_true(seq_view._offsets is not self.seq._offsets)
- assert_true(seq_view._lengths is not self.seq._lengths)
- assert_array_equal(seq_view._offsets, self.seq._offsets)
- assert_array_equal(seq_view._lengths, self.seq._lengths)
- assert_arrays_equal(seq_view, self.seq)
-
- # Get multiple items using ndarray of data type.
+ for i, e in enumerate(SEQ_DATA['seq']):
+ assert_array_equal(SEQ_DATA['seq'][i], e)
+
+ # Get all items using indexing (creates a view).
+ indices = list(range(len(SEQ_DATA['seq'])))
+ seq_view = SEQ_DATA['seq'][indices]
+ check_arr_seq_view(seq_view, SEQ_DATA['seq'])
+ # We took all elements so the view should match the original.
+ check_arr_seq(seq_view, SEQ_DATA['seq'])
+
+ # Get multiple items using ndarray of dtype integer.
 for dtype in [np.int8, np.int16, np.int32, np.int64]:
- seq_view = self.seq[np.array(indices, dtype=dtype)]
- assert_true(seq_view is not self.seq)
- assert_true(seq_view._data is self.seq._data)
- assert_true(seq_view._offsets is not self.seq._offsets)
- assert_true(seq_view._lengths is not self.seq._lengths)
- assert_array_equal(seq_view._offsets, self.seq._offsets)
- assert_array_equal(seq_view._lengths, self.seq._lengths)
- for e1, e2 in zip_longest(seq_view, self.seq):
- assert_array_equal(e1, e2)
+ seq_view = SEQ_DATA['seq'][np.array(indices, dtype=dtype)]
+ check_arr_seq_view(seq_view, SEQ_DATA['seq'])
+ # We took all elements so the view should match the original.
+ check_arr_seq(seq_view, SEQ_DATA['seq'])
+
+ # Get multiple items out of order (creates a view).
+ SEQ_DATA['rng'].shuffle(indices)
+ seq_view = SEQ_DATA['seq'][indices]
+ check_arr_seq_view(seq_view, SEQ_DATA['seq'])
+ check_arr_seq(seq_view, np.asarray(SEQ_DATA['data'])[indices])

 # Get slice (this will create a view).
- seq_view = self.seq[::2]
- assert_true(seq_view is not self.seq)
- assert_true(seq_view._data is self.seq._data)
- assert_array_equal(seq_view._offsets, self.seq._offsets[::2])
- assert_array_equal(seq_view._lengths, self.seq._lengths[::2])
- for i, e in enumerate(seq_view):
- assert_array_equal(e, self.seq[i*2])
+ seq_view = SEQ_DATA['seq'][::2]
+ check_arr_seq_view(seq_view, SEQ_DATA['seq'])
+ check_arr_seq(seq_view, SEQ_DATA['data'][::2])

 # Use advanced indexing with ndarray of data type bool.
- idx = np.array([False, True, True, False, True])
- seq_view = self.seq[idx]
- assert_true(seq_view is not self.seq)
- assert_true(seq_view._data is self.seq._data)
- assert_array_equal(seq_view._offsets,
- self.seq._offsets[idx])
- assert_array_equal(seq_view._lengths,
- self.seq._lengths[idx])
- assert_array_equal(seq_view[0], self.seq[1])
- assert_array_equal(seq_view[1], self.seq[2])
- assert_array_equal(seq_view[2], self.seq[4])
+ selection = np.array([False, True, True, False, True])
+ seq_view = SEQ_DATA['seq'][selection]
+ check_arr_seq_view(seq_view, SEQ_DATA['seq'])
+ check_arr_seq(seq_view, np.asarray(SEQ_DATA['data'])[selection])

 # Test invalid indexing
- assert_raises(TypeError, self.seq.__getitem__, 'abc')
+ assert_raises(TypeError, SEQ_DATA['seq'].__getitem__, 'abc')

 def test_arraysequence_repr(self):
 # Test that calling repr on an ArraySequence object does not fail.
- repr(self.seq)
+ repr(SEQ_DATA['seq'])

 def test_save_and_load_arraysequence(self):
- # Test saving and loading an empty ArraySequence.
with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f:
 seq = ArraySequence()
@@ -299,9 +246,7 @@ def test_save_and_load_arraysequence(self):

 # Test saving and loading an ArraySequence.
 with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f:
- rng = np.random.RandomState(42)
- data = [rng.rand(rng.randint(10, 50), 3) for _ in range(10)]
- seq = ArraySequence(data)
+ seq = SEQ_DATA['seq']
 seq.save(f)
 f.seek(0, os.SEEK_SET)
 loaded_seq = ArraySequence.from_filename(f)

From 6aa70b74eaf41403b9792b4f117dd62f832112ba Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3%B4t=C3=A9?=
Date: Thu, 18 Feb 2016 23:32:31 -0500
Subject: [PATCH 082/135] Addressed @matthew-brett's comments.

---
 .../streamlines/tests/test_array_sequence.py | 2 +-
 nibabel/streamlines/tests/test_streamlines.py | 140 +++++++++---------
 nibabel/streamlines/tests/test_tractogram.py | 20 ++-
 nibabel/streamlines/tractogram.py | 47 ++----
 4 files changed, 101 insertions(+), 108 deletions(-)

diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py
index fe8e9da9da..245d2afd53 100644
--- a/nibabel/streamlines/tests/test_array_sequence.py
+++ b/nibabel/streamlines/tests/test_array_sequence.py
@@ -46,7 +46,7 @@ def check_arr_seq(seq, arrays):
 assert_equal(seq.common_shape, arrays[0].shape[1:])
 assert_arrays_equal(seq, arrays)

- # If seq is a view, there order of internal data is not guarantied.
+ # If seq is a view, then the order of internal data is not guaranteed.
 if seq._is_view:
 # The only thing we can check is the _lengths.
diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py
index 92893792ea..32f0e1a736 100644
--- a/nibabel/streamlines/tests/test_streamlines.py
+++ b/nibabel/streamlines/tests/test_streamlines.py
@@ -15,11 +15,63 @@
 from .test_tractogram import assert_tractogram_equal
 from ..tractogram import Tractogram, LazyTractogram
 from ..tractogram_file import TractogramFile, ExtensionWarning
-from ..tractogram import UsageWarning
 from ..
import trk

 DATA_PATH = pjoin(os.path.dirname(__file__), 'data')

+DATA = {}
+
+
+def setup():
+ global DATA
+ DATA['empty_filenames'] = [pjoin(DATA_PATH, "empty" + ext)
+ for ext in nib.streamlines.FORMATS.keys()]
+ DATA['simple_filenames'] = [pjoin(DATA_PATH, "simple" + ext)
+ for ext in nib.streamlines.FORMATS.keys()]
+ DATA['complex_filenames'] = [pjoin(DATA_PATH, "complex" + ext)
+ for ext in nib.streamlines.FORMATS.keys()]
+
+ DATA['streamlines'] = [np.arange(1*3, dtype="f4").reshape((1, 3)),
+ np.arange(2*3, dtype="f4").reshape((2, 3)),
+ np.arange(5*3, dtype="f4").reshape((5, 3))]
+
+ fa = [np.array([[0.2]], dtype="f4"),
+ np.array([[0.3],
+ [0.4]], dtype="f4"),
+ np.array([[0.5],
+ [0.6],
+ [0.6],
+ [0.7],
+ [0.8]], dtype="f4")]
+
+ colors = [np.array([(1, 0, 0)]*1, dtype="f4"),
+ np.array([(0, 1, 0)]*2, dtype="f4"),
+ np.array([(0, 0, 1)]*5, dtype="f4")]
+
+ mean_curvature = [np.array([1.11], dtype="f4"),
+ np.array([2.11], dtype="f4"),
+ np.array([3.11], dtype="f4")]
+
+ mean_torsion = [np.array([1.22], dtype="f4"),
+ np.array([2.22], dtype="f4"),
+ np.array([3.22], dtype="f4")]
+
+ mean_colors = [np.array([1, 0, 0], dtype="f4"),
+ np.array([0, 1, 0], dtype="f4"),
+ np.array([0, 0, 1], dtype="f4")]
+
+ DATA['data_per_point'] = {'colors': colors,
+ 'fa': fa}
+ DATA['data_per_streamline'] = {'mean_curvature': mean_curvature,
+ 'mean_torsion': mean_torsion,
+ 'mean_colors': mean_colors}
+
+ DATA['empty_tractogram'] = Tractogram()
+ DATA['simple_tractogram'] = Tractogram(DATA['streamlines'])
+ DATA['complex_tractogram'] = Tractogram(DATA['streamlines'],
+ DATA['data_per_streamline'],
+ DATA['data_per_point'])
+

 def test_is_supported():
 # Empty file/string
@@ -101,58 +153,10 @@ def test_detect_format():

 class TestLoadSave(unittest.TestCase):
- def setUp(self):

 def test_load_empty_file(self):
 for lazy_load in [False, True]:
- for empty_filename in self.empty_filenames:
+ for empty_filename in DATA['empty_filenames']:
 tfile = nib.streamlines.load(empty_filename,
lazy_load=lazy_load) assert_true(isinstance(tfile, TractogramFile)) @@ -163,11 +167,11 @@ def test_load_empty_file(self): assert_true(type(tfile.tractogram), LazyTractogram) assert_tractogram_equal(tfile.tractogram, - self.empty_tractogram) + DATA['empty_tractogram']) def test_load_simple_file(self): for lazy_load in [False, True]: - for simple_filename in self.simple_filenames: + for simple_filename in DATA['simple_filenames']: tfile = nib.streamlines.load(simple_filename, lazy_load=lazy_load) assert_true(isinstance(tfile, TractogramFile)) @@ -178,11 +182,11 @@ def test_load_simple_file(self): assert_true(type(tfile.tractogram), LazyTractogram) assert_tractogram_equal(tfile.tractogram, - self.simple_tractogram) + DATA['simple_tractogram']) def test_load_complex_file(self): for lazy_load in [False, True]: - for complex_filename in self.complex_filenames: + for complex_filename in DATA['complex_filenames']: tfile = nib.streamlines.load(complex_filename, lazy_load=lazy_load) assert_true(isinstance(tfile, TractogramFile)) @@ -192,19 +196,19 @@ def test_load_complex_file(self): else: assert_true(type(tfile.tractogram), LazyTractogram) - tractogram = Tractogram(self.streamlines) + tractogram = Tractogram(DATA['streamlines']) if tfile.support_data_per_point(): - tractogram.data_per_point = self.data_per_point + tractogram.data_per_point = DATA['data_per_point'] if tfile.support_data_per_streamline(): - tractogram.data_per_streamline = self.data_per_streamline + tractogram.data_per_streamline = DATA['data_per_streamline'] assert_tractogram_equal(tfile.tractogram, tractogram) def test_save_tractogram_file(self): - tractogram = Tractogram(self.streamlines) + tractogram = Tractogram(DATA['streamlines']) trk_file = trk.TrkFile(tractogram) # No need for keyword arguments. @@ -232,7 +236,7 @@ def test_save_empty_file(self): assert_tractogram_equal(tfile.tractogram, tractogram) def test_save_simple_file(self): - tractogram = Tractogram(self.streamlines) + tractogram = Tractogram(DATA['streamlines']) for ext, cls in nib.streamlines.FORMATS.items(): with InTemporaryDirectory(): with open('streamlines' + ext, 'w+b') as f: @@ -241,9 +245,9 @@ def test_save_simple_file(self): assert_tractogram_equal(tfile.tractogram, tractogram) def test_save_complex_file(self): - complex_tractogram = Tractogram(self.streamlines, - self.data_per_streamline, - self.data_per_point) + complex_tractogram = Tractogram(DATA['streamlines'], + DATA['data_per_streamline'], + DATA['data_per_point']) for ext, cls in nib.streamlines.FORMATS.items(): with InTemporaryDirectory(): @@ -255,18 +259,18 @@ def test_save_complex_file(self): # If streamlines format does not support saving data # per point or data per streamline, a warning message # should be issued. 
- if not (cls.support_data_per_point() - and cls.support_data_per_streamline()): + if not (cls.support_data_per_point() and + cls.support_data_per_streamline()): assert_equal(len(w), 1) - assert_true(issubclass(w[0].category, UsageWarning)) + assert_true(issubclass(w[0].category, Warning)) - tractogram = Tractogram(self.streamlines) + tractogram = Tractogram(DATA['streamlines']) if cls.support_data_per_point(): - tractogram.data_per_point = self.data_per_point + tractogram.data_per_point = DATA['data_per_point'] if cls.support_data_per_streamline(): - tractogram.data_per_streamline = self.data_per_streamline + tractogram.data_per_streamline = DATA['data_per_streamline'] tfile = nib.streamlines.load(f, lazy_load=False) assert_tractogram_equal(tfile.tractogram, tractogram) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 68c8f61438..97c139c996 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -9,7 +9,6 @@ from nibabel.externals.six.moves import zip from .. import tractogram as module_tractogram -from ..tractogram import UsageWarning from ..tractogram import TractogramItem, Tractogram, LazyTractogram @@ -198,6 +197,13 @@ def test_tractogram_creation(self): assert_equal(tractogram.data_per_point, {}) assert_true(check_iteration(tractogram)) + # Create a tractogram with streamlines and a given affine. + affine = np.diag([1, 2, 3, 1]) + tractogram = Tractogram(streamlines=self.streamlines, + affine_to_rasmm=affine) + assert_arrays_equal(tractogram.streamlines, self.streamlines) + assert_array_equal(tractogram.get_affine_to_rasmm(), affine) + # Create a tractogram with streamlines and other data. tractogram = Tractogram( self.streamlines, @@ -409,7 +415,7 @@ def test_tractogram_apply_affine(self): tractogram.streamlines): assert_array_almost_equal(s1, s2*scaling) - assert_array_equal(transformed_tractogram.affine_to_rasmm, + assert_array_equal(transformed_tractogram.get_affine_to_rasmm(), np.dot(np.eye(4), np.linalg.inv(affine))) # Apply the affine to the streamlines in-place. @@ -421,9 +427,9 @@ def test_tractogram_apply_affine(self): self.streamlines): assert_array_almost_equal(s1, s2*scaling) - # Apply affine again and check the affine_to_rasmm property. + # Apply affine again and check the affine_to_rasmm. transformed_tractogram = tractogram.apply_affine(affine) - assert_array_equal(transformed_tractogram.affine_to_rasmm, + assert_array_equal(transformed_tractogram.get_affine_to_rasmm(), np.dot(np.eye(4), np.dot(np.linalg.inv(affine), np.linalg.inv(affine)))) @@ -573,7 +579,7 @@ def test_lazy_tractogram_len(self): # New instances should still produce a warning message. assert_equal(len(tractogram), self.nb_streamlines) assert_equal(len(w), 2) - assert_true(issubclass(w[-1].category, UsageWarning)) + assert_true(issubclass(w[-1].category, Warning)) # Calling again 'len' again should *not* produce a warning. assert_equal(len(tractogram), self.nb_streamlines) @@ -613,9 +619,9 @@ def test_lazy_tractogram_apply_affine(self): for s1, s2 in zip(tractogram.streamlines, self.streamlines): assert_array_almost_equal(s1, s2*scaling) - # Apply affine again and check the affine_to_rasmm property. + # Apply affine again and check the affine_to_rasmm. 
transformed_tractogram = tractogram.apply_affine(affine)
- assert_array_equal(transformed_tractogram.affine_to_rasmm,
+ assert_array_equal(transformed_tractogram.get_affine_to_rasmm(),
 np.dot(np.eye(4),
 np.dot(np.linalg.inv(affine),
 np.linalg.inv(affine))))
diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
index f095c32c9a..1b23843c5b 100644
--- a/nibabel/streamlines/tractogram.py
+++ b/nibabel/streamlines/tractogram.py
@@ -1,3 +1,4 @@
+import copy
 import numpy as np
 import collections
 from warnings import warn
@@ -7,10 +8,6 @@
 from .array_sequence import ArraySequence

-class UsageWarning(Warning):
- pass
-
-
 class TractogramItem(object):
 """ Class containing information about one streamline.
@@ -50,11 +47,6 @@ class Tractogram(object):
 Tractogram objects have three main properties: `streamlines`,
 `data_per_streamline` and `data_per_point`.

- Attributes
- ----------
- affine_to_rasmm : ndarray shape (4, 4)
- Affine that brings the streamlines back to *RAS+* and *mm* space
- where coordinate (0,0,0) refers to the center of the voxel.
 """
 class DataDict(collections.MutableMapping):
 def __init__(self, tractogram, *args, **kwargs):
@@ -85,7 +77,7 @@ def __len__(self):
 return len(self.store)

 class DataPerStreamlineDict(DataDict):
- """ Internal dictionary that makes sure data are 2D array. """
+ """ Dictionary that makes sure data are 2D arrays. """

 def __setitem__(self, key, value):
 value = np.asarray(value)
@@ -107,7 +99,7 @@ def __setitem__(self, key, value):
 self.store[key] = value

 class DataPerPointDict(DataDict):
- """ Internal dictionary making sure data are :class:`ArraySequence` objects. """
+ """ Dictionary making sure data are :class:`ArraySequence` objects. """

 def __setitem__(self, key, value):
 value = ArraySequence(value)
@@ -125,7 +117,8 @@ def __setitem__(self, key, value):

 def __init__(self, streamlines=None,
 data_per_streamline=None,
- data_per_point=None):
+ data_per_point=None,
+ affine_to_rasmm=np.eye(4)):
 """
 Parameters
 ----------
@@ -142,11 +135,15 @@ def __init__(self, streamlines=None,
 points for a particular streamline t and M is the number
 of scalars associated to each point (excluding the three
 coordinates).
+ affine_to_rasmm : ndarray of shape (4, 4)
+ Transformation matrix that brings the streamlines contained in
+ this tractogram to *RAS+* and *mm* space where coordinate (0,0,0)
+ refers to the center of the voxel.
 """
 self.streamlines = streamlines
 self.data_per_streamline = data_per_streamline
 self.data_per_point = data_per_point
- self._affine_to_rasmm = np.eye(4)
+ self._affine_to_rasmm = affine_to_rasmm

 @property
 def streamlines(self):
@@ -173,9 +170,8 @@ def data_per_point(self):
 def data_per_point(self, value):
 self._data_per_point = Tractogram.DataPerPointDict(self, value)

- @property
- def affine_to_rasmm(self):
- # Return a copy. User should use self.apply_affine` to modify it.
+ def get_affine_to_rasmm(self):
+ """ Returns the affine bringing this tractogram to RAS+mm. """
 return self._affine_to_rasmm.copy()

 def __iter__(self):
@@ -203,20 +199,7 @@ def __len__(self):

 def copy(self):
 """ Returns a copy of this :class:`Tractogram` object.
""" - data_per_streamline = {} - for key in self.data_per_streamline: - data_per_streamline[key] = self.data_per_streamline[key].copy() - - data_per_point = {} - for key in self.data_per_point: - data_per_point[key] = self.data_per_point[key].copy() - - tractogram = Tractogram(self.streamlines.copy(), - data_per_streamline, - data_per_point) - - tractogram._affine_to_rasmm = self.affine_to_rasmm - return tractogram + return copy.deepcopy(self) def apply_affine(self, affine, lazy=False): """ Applies an affine transformation on the points of each streamline. @@ -359,7 +342,7 @@ def from_tractogram(cls, tractogram): data_per_point) lazy_tractogram._nb_streamlines = len(tractogram) - lazy_tractogram._affine_to_rasmm = tractogram.affine_to_rasmm + lazy_tractogram._affine_to_rasmm = tractogram.get_affine_to_rasmm() return lazy_tractogram @classmethod @@ -502,7 +485,7 @@ def __len__(self): " streamlines, you might want to set it beforehand via" " `self.header.nb_streamlines`." " Note this will consume any generators used to create this" - " `LazyTractogram` object.", UsageWarning) + " `LazyTractogram` object.", Warning) # Count the number of streamlines. self._nb_streamlines = sum(1 for _ in self.streamlines) From 1e395cc885005b4dedf459ce4043a44baf0d2304 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 19 Feb 2016 09:53:30 -0500 Subject: [PATCH 083/135] Refactored DataDict following @matthew-brett's suggestion so Tractogram can be pickle. --- nibabel/streamlines/tests/test_tractogram.py | 670 ++++++++----------- nibabel/streamlines/tractogram.py | 237 ++++--- 2 files changed, 409 insertions(+), 498 deletions(-) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 97c139c996..bf4b936fad 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -3,29 +3,126 @@ import warnings from nibabel.testing import assert_arrays_equal, check_iteration -from nibabel.testing import suppress_warnings, clear_and_catch_warnings +from nibabel.testing import clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true from numpy.testing import assert_array_equal, assert_array_almost_equal from nibabel.externals.six.moves import zip from .. 
import tractogram as module_tractogram from ..tractogram import TractogramItem, Tractogram, LazyTractogram +from ..tractogram import DataPerStreamlineDict, DataPerPointDict, LazyDict + +DATA = {} + + +def setup(): + global DATA + DATA['streamlines'] = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + DATA['fa'] = [np.array([[0.2]], dtype="f4"), + np.array([[0.3], + [0.4]], dtype="f4"), + np.array([[0.5], + [0.6], + [0.6], + [0.7], + [0.8]], dtype="f4")] + + DATA['colors'] = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + DATA['mean_curvature'] = [np.array([1.11], dtype="f4"), + np.array([2.11], dtype="f4"), + np.array([3.11], dtype="f4")] + + DATA['mean_torsion'] = [np.array([1.22], dtype="f4"), + np.array([2.22], dtype="f4"), + np.array([3.22], dtype="f4")] + + DATA['mean_colors'] = [np.array([1, 0, 0], dtype="f4"), + np.array([0, 1, 0], dtype="f4"), + np.array([0, 0, 1], dtype="f4")] + + DATA['data_per_point'] = {'colors': DATA['colors'], + 'fa': DATA['fa']} + DATA['data_per_streamline'] = {'mean_curvature': DATA['mean_curvature'], + 'mean_torsion': DATA['mean_torsion'], + 'mean_colors': DATA['mean_colors']} + + DATA['empty_tractogram'] = Tractogram() + DATA['simple_tractogram'] = Tractogram(DATA['streamlines']) + DATA['tractogram'] = Tractogram(DATA['streamlines'], + DATA['data_per_streamline'], + DATA['data_per_point']) + + DATA['streamlines_func'] = lambda: (e for e in DATA['streamlines']) + fa_func = lambda: (e for e in DATA['fa']) + colors_func = lambda: (e for e in DATA['colors']) + mean_curvature_func = lambda: (e for e in DATA['mean_curvature']) + mean_torsion_func = lambda: (e for e in DATA['mean_torsion']) + mean_colors_func = lambda: (e for e in DATA['mean_colors']) + + DATA['data_per_point_func'] = {'colors': colors_func, + 'fa': fa_func} + DATA['data_per_streamline_func'] = {'mean_curvature': mean_curvature_func, + 'mean_torsion': mean_torsion_func, + 'mean_colors': mean_colors_func} + + DATA['lazy_tractogram'] = LazyTractogram(DATA['streamlines_func'], + DATA['data_per_streamline_func'], + DATA['data_per_point_func']) + + +def check_tractogram_item(tractogram_item, + streamline, + data_for_streamline={}, + data_for_points={}): + + assert_array_equal(tractogram_item.streamline, streamline) + + assert_equal(len(tractogram_item.data_for_streamline), + len(data_for_streamline)) + for key in data_for_streamline.keys(): + assert_array_equal(tractogram_item.data_for_streamline[key], + data_for_streamline[key]) + + assert_equal(len(tractogram_item.data_for_points), len(data_for_points)) + for key in data_for_points.keys(): + assert_arrays_equal(tractogram_item.data_for_points[key], + data_for_points[key]) + + +def assert_tractogram_item_equal(t1, t2): + check_tractogram_item(t1, t2.streamline, + t2.data_for_streamline, t2.data_for_points) + + +def check_tractogram(tractogram, + streamlines=[], + data_per_streamline={}, + data_per_point={}): + streamlines = list(streamlines) + assert_equal(len(tractogram), len(streamlines)) + assert_arrays_equal(tractogram.streamlines, streamlines) + assert_true(check_iteration(tractogram)) + + assert_equal(len(tractogram.data_per_streamline), len(data_per_streamline)) + for key in data_per_streamline.keys(): + assert_arrays_equal(tractogram.data_per_streamline[key], + data_per_streamline[key]) + + assert_equal(len(tractogram.data_per_point), len(data_per_point)) + for key in 
data_per_point.keys(): + assert_arrays_equal(tractogram.data_per_point[key], + data_per_point[key]) def assert_tractogram_equal(t1, t2): - assert_true(check_iteration(t1)) - assert_equal(len(t1), len(t2)) - assert_arrays_equal(t1.streamlines, t2.streamlines) - - assert_equal(len(t1.data_per_streamline), len(t2.data_per_streamline)) - for key in t1.data_per_streamline.keys(): - assert_arrays_equal(t1.data_per_streamline[key], - t2.data_per_streamline[key]) - - assert_equal(len(t1.data_per_point), len(t2.data_per_point)) - for key in t1.data_per_point.keys(): - assert_arrays_equal(t1.data_per_point[key], - t2.data_per_point[key]) + check_tractogram(t1, t2.streamlines, + t2.data_per_streamline, t2.data_per_point) class TestTractogramItem(unittest.TestCase): @@ -57,46 +154,24 @@ def test_creating_tractogram_item(self): class TestTractogramDataDict(unittest.TestCase): - def setUp(self): - self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), - np.arange(2*3, dtype="f4").reshape((2, 3)), - np.arange(5*3, dtype="f4").reshape((5, 3))] - - self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), - np.array([(0, 1, 0)]*2, dtype="f4"), - np.array([(0, 0, 1)]*5, dtype="f4")] - - self.mean_curvature = np.array([1.11, 2.11, 3.11], dtype="f4") - self.mean_color = np.array([[0, 1, 0], - [0, 0, 1], - [1, 0, 0]], dtype="f4") - - self.nb_streamlines = len(self.streamlines) - - # Create a tractogram with streamlines and other data. - self.tractogram = Tractogram( - self.streamlines, - data_per_streamline={'mean_curvature': self.mean_curvature, - 'mean_color': self.mean_color}, - data_per_point={'colors': self.colors}) - def test_datadict_creation(self): # Create a DataPerStreamlineDict object using another # DataPerStreamlineDict object. - data_per_streamline = self.tractogram.data_per_streamline - data_dict = Tractogram.DataPerStreamlineDict(self.tractogram, - data_per_streamline) + data_per_streamline = DATA['tractogram'].data_per_streamline + data_dict = DataPerStreamlineDict(DATA['tractogram'], + data_per_streamline) assert_equal(data_dict.keys(), data_per_streamline.keys()) for k in data_dict.keys(): assert_array_equal(data_dict[k], data_per_streamline[k]) del data_dict['mean_curvature'] - assert_equal(len(data_dict), len(self.tractogram.data_per_streamline)-1) + assert_equal(len(data_dict), + len(DATA['tractogram'].data_per_streamline)-1) # Create a DataPerStreamlineDict object using an existing dict object. - data_per_streamline = self.tractogram.data_per_streamline.store - data_dict = Tractogram.DataPerStreamlineDict(self.tractogram, - data_per_streamline) + data_per_streamline = DATA['tractogram'].data_per_streamline.store + data_dict = DataPerStreamlineDict(DATA['tractogram'], + data_per_streamline) assert_equal(data_dict.keys(), data_per_streamline.keys()) for k in data_dict.keys(): assert_array_equal(data_dict[k], data_per_streamline[k]) @@ -105,9 +180,9 @@ def test_datadict_creation(self): assert_equal(len(data_dict), len(data_per_streamline)-1) # Create a DataPerStreamlineDict object using keyword arguments. 
- data_per_streamline = self.tractogram.data_per_streamline.store - data_dict = Tractogram.DataPerStreamlineDict(self.tractogram, - **data_per_streamline) + data_per_streamline = DATA['tractogram'].data_per_streamline.store + data_dict = DataPerStreamlineDict(DATA['tractogram'], + **data_per_streamline) assert_equal(data_dict.keys(), data_per_streamline.keys()) for k in data_dict.keys(): assert_array_equal(data_dict[k], data_per_streamline[k]) @@ -115,112 +190,56 @@ def test_datadict_creation(self): del data_dict['mean_curvature'] assert_equal(len(data_dict), len(data_per_streamline)-1) + def test_getitem(self): + data_dict = DataPerPointDict(DATA['tractogram'], + DATA['data_per_point']) -class TestTractogramLazyDict(unittest.TestCase): - - def setUp(self): - self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), - np.arange(2*3, dtype="f4").reshape((2, 3)), - np.arange(5*3, dtype="f4").reshape((5, 3))] - - self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), - np.array([(0, 1, 0)]*2, dtype="f4"), - np.array([(0, 0, 1)]*5, dtype="f4")] - - self.mean_curvature = np.array([1.11, 2.11, 3.11], dtype="f4") - self.mean_color = np.array([[0, 1, 0], - [0, 0, 1], - [1, 0, 0]], dtype="f4") + assert_true('fa' in data_dict) + assert_arrays_equal(data_dict['fa'], DATA['fa']) + assert_arrays_equal(data_dict[::2]['fa'], DATA['fa'][::2]) + assert_arrays_equal(data_dict[::-1]['fa'], DATA['fa'][::-1]) + assert_arrays_equal(data_dict[-1]['fa'], DATA['fa'][-1]) + assert_raises(KeyError, data_dict.__getitem__, 'invalid') - self.nb_streamlines = len(self.streamlines) - self.colors_func = lambda: (x for x in self.colors) - self.mean_curvature_func = lambda: (x for x in self.mean_curvature) - self.mean_color_func = lambda: (x for x in self.mean_color) - - streamlines = lambda: (x for x in self.streamlines) - data_per_point = {"colors": self.colors_func} - data_per_streamline = {'mean_curvature': self.mean_curvature_func, - 'mean_color': self.mean_color_func} - - # Create a tractogram with streamlines and other data. - self.tractogram = LazyTractogram( - streamlines, - data_per_streamline=data_per_streamline, - data_per_point=data_per_point) +class TestTractogramLazyDict(unittest.TestCase): def test_lazydict_creation(self): - # Create a DataPerStreamlineDict object using another - # DataPerStreamlineDict object. 
- data_per_streamline = self.tractogram.data_per_streamline - data_dict = LazyTractogram.LazyDict(data_per_streamline) - assert_equal(data_dict.keys(), data_per_streamline.keys()) + data_dict = LazyDict(None, DATA['data_per_streamline_func']) + assert_equal(data_dict.keys(), DATA['data_per_streamline_func'].keys()) for k in data_dict.keys(): - assert_array_equal(list(data_dict[k]), list(data_per_streamline[k])) + assert_array_equal(list(data_dict[k]), + list(DATA['data_per_streamline'][k])) - del data_dict['mean_curvature'] - assert_equal(len(data_dict), len(self.tractogram.data_per_streamline)-1) + assert_equal(len(data_dict), + len(DATA['data_per_streamline_func'])) class TestTractogram(unittest.TestCase): - def setUp(self): - self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), - np.arange(2*3, dtype="f4").reshape((2, 3)), - np.arange(5*3, dtype="f4").reshape((5, 3))] - - self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), - np.array([(0, 1, 0)]*2, dtype="f4"), - np.array([(0, 0, 1)]*5, dtype="f4")] - - self.mean_curvature = np.array([1.11, 2.11, 3.11], dtype="f4") - self.mean_color = np.array([[0, 1, 0], - [0, 0, 1], - [1, 0, 0]], dtype="f4") - - self.nb_streamlines = len(self.streamlines) - def test_tractogram_creation(self): # Create an empty tractogram. tractogram = Tractogram() - assert_equal(len(tractogram), 0) - assert_arrays_equal(tractogram.streamlines, []) - assert_equal(tractogram.data_per_streamline, {}) - assert_equal(tractogram.data_per_point, {}) - assert_true(check_iteration(tractogram)) + check_tractogram(tractogram) # Create a tractogram with only streamlines - tractogram = Tractogram(streamlines=self.streamlines) - assert_equal(len(tractogram), len(self.streamlines)) - assert_arrays_equal(tractogram.streamlines, self.streamlines) - assert_equal(tractogram.data_per_streamline, {}) - assert_equal(tractogram.data_per_point, {}) - assert_true(check_iteration(tractogram)) + tractogram = Tractogram(streamlines=DATA['streamlines']) + check_tractogram(tractogram, DATA['streamlines']) - # Create a tractogram with streamlines and a given affine. + # Create a tractogram with a given affine_to_rasmm. affine = np.diag([1, 2, 3, 1]) - tractogram = Tractogram(streamlines=self.streamlines, - affine_to_rasmm=affine) - assert_arrays_equal(tractogram.streamlines, self.streamlines) + tractogram = Tractogram(affine_to_rasmm=affine) assert_array_equal(tractogram.get_affine_to_rasmm(), affine) # Create a tractogram with streamlines and other data. 
- tractogram = Tractogram( - self.streamlines, - data_per_streamline={'mean_curvature': self.mean_curvature, - 'mean_color': self.mean_color}, - data_per_point={'colors': self.colors}) - - assert_equal(len(tractogram), len(self.streamlines)) - assert_arrays_equal(tractogram.streamlines, self.streamlines) - assert_arrays_equal(tractogram.data_per_streamline['mean_curvature'], - self.mean_curvature) - assert_arrays_equal(tractogram.data_per_streamline['mean_color'], - self.mean_color) - assert_arrays_equal(tractogram.data_per_point['colors'], - self.colors) + tractogram = Tractogram(DATA['streamlines'], + DATA['data_per_streamline'], + DATA['data_per_point']) - assert_true(check_iteration(tractogram)) + check_tractogram(tractogram, + DATA['streamlines'], + DATA['data_per_streamline'], + DATA['data_per_point']) # Inconsistent number of scalars between streamlines wrong_data = [[(1, 0, 0)]*1, @@ -228,7 +247,7 @@ def test_tractogram_creation(self): [(0, 0, 1)]*5] data_per_point = {'wrong_data': wrong_data} - assert_raises(ValueError, Tractogram, self.streamlines, + assert_raises(ValueError, Tractogram, DATA['streamlines'], data_per_point=data_per_point) # Inconsistent number of scalars between streamlines @@ -237,118 +256,69 @@ def test_tractogram_creation(self): [(0, 0, 1)]*5] data_per_point = {'wrong_data': wrong_data} - assert_raises(ValueError, Tractogram, self.streamlines, + assert_raises(ValueError, Tractogram, DATA['streamlines'], data_per_point=data_per_point) def test_tractogram_getitem(self): - # Tractogram with only streamlines - tractogram = Tractogram(streamlines=self.streamlines) + # Retrieve TractogramItem by their index. + for i, t in enumerate(DATA['tractogram']): + assert_tractogram_item_equal(DATA['tractogram'][i], t) - selected_tractogram = tractogram[::2] - assert_equal(len(selected_tractogram), (len(self.streamlines)+1)//2) + # Get one TractogramItem out of two. + tractogram_view = DATA['simple_tractogram'][::2] + check_tractogram(tractogram_view, DATA['streamlines'][::2]) - assert_arrays_equal(selected_tractogram.streamlines, - self.streamlines[::2]) - assert_equal(tractogram.data_per_streamline, {}) - assert_equal(tractogram.data_per_point, {}) - - # Create a tractogram with streamlines and other data. - tractogram = Tractogram( - self.streamlines, - data_per_streamline={'mean_curvature': self.mean_curvature, - 'mean_color': self.mean_color}, - data_per_point={'colors': self.colors}) - - # Retrieve tractogram by their index - for i, t in enumerate(tractogram): - assert_array_equal(t.streamline, tractogram[i].streamline) - assert_array_equal(t.data_for_points['colors'], - tractogram[i].data_for_points['colors']) - - assert_array_equal(t.data_for_streamline['mean_curvature'], - tractogram[i].data_for_streamline['mean_curvature']) - - assert_array_equal(t.data_for_streamline['mean_color'], - tractogram[i].data_for_streamline['mean_color']) - - # Use slicing - r_tractogram = tractogram[::-1] - assert_arrays_equal(r_tractogram.streamlines, self.streamlines[::-1]) - - assert_arrays_equal(r_tractogram.data_per_streamline['mean_curvature'], - self.mean_curvature[::-1]) - assert_arrays_equal(r_tractogram.data_per_streamline['mean_color'], - self.mean_color[::-1]) - assert_arrays_equal(r_tractogram.data_per_point['colors'], - self.colors[::-1]) + # Use slicing. 
+ r_tractogram = DATA['tractogram'][::-1] + check_tractogram(r_tractogram, + DATA['streamlines'][::-1], + DATA['tractogram'].data_per_streamline[::-1], + DATA['tractogram'].data_per_point[::-1]) def test_tractogram_add_new_data(self): # Tractogram with only streamlines - tractogram = Tractogram(streamlines=self.streamlines) - - tractogram.data_per_streamline['mean_curvature'] = self.mean_curvature - tractogram.data_per_streamline['mean_color'] = self.mean_color - tractogram.data_per_point['colors'] = self.colors - - # Retrieve tractogram by their index - for i, t in enumerate(tractogram): - assert_array_equal(t.streamline, tractogram[i].streamline) - assert_array_equal(t.data_for_points['colors'], - tractogram[i].data_for_points['colors']) - - assert_array_equal(t.data_for_streamline['mean_curvature'], - tractogram[i].data_for_streamline['mean_curvature']) - - assert_array_equal(t.data_for_streamline['mean_color'], - tractogram[i].data_for_streamline['mean_color']) - - # Use slicing - r_tractogram = tractogram[::-1] - assert_arrays_equal(r_tractogram.streamlines, self.streamlines[::-1]) - - assert_arrays_equal(r_tractogram.data_per_streamline['mean_curvature'], - self.mean_curvature[::-1]) - assert_arrays_equal(r_tractogram.data_per_streamline['mean_color'], - self.mean_color[::-1]) - assert_arrays_equal(r_tractogram.data_per_point['colors'], - self.colors[::-1]) + t = DATA['simple_tractogram'].copy() + t.data_per_point['fa'] = DATA['fa'] + t.data_per_point['colors'] = DATA['colors'] + t.data_per_streamline['mean_curvature'] = DATA['mean_curvature'] + t.data_per_streamline['mean_torsion'] = DATA['mean_torsion'] + t.data_per_streamline['mean_colors'] = DATA['mean_colors'] + assert_tractogram_equal(t, DATA['tractogram']) + + # Retrieve tractogram by their index. + for i, item in enumerate(t): + assert_tractogram_item_equal(t[i], item) + + # Use slicing. + r_tractogram = t[::-1] + check_tractogram(r_tractogram, + t.streamlines[::-1], + t.data_per_streamline[::-1], + t.data_per_point[::-1]) def test_tractogram_copy(self): - # Create a tractogram with streamlines and other data. - tractogram1 = Tractogram( - self.streamlines, - data_per_streamline={'mean_curvature': self.mean_curvature, - 'mean_color': self.mean_color}, - data_per_point={'colors': self.colors}) - - # Create a copy of the tractogram. - tractogram2 = tractogram1.copy() + # Create a copy of a tractogram. + tractogram = DATA['tractogram'].copy() # Check we copied the data and not simply created new references. - assert_true(tractogram1 is not tractogram2) - assert_true(tractogram1.streamlines is not tractogram2.streamlines) - assert_true(tractogram1.data_per_streamline - is not tractogram2.data_per_streamline) - assert_true(tractogram1.data_per_streamline['mean_curvature'] - is not tractogram2.data_per_streamline['mean_curvature']) - assert_true(tractogram1.data_per_streamline['mean_color'] - is not tractogram2.data_per_streamline['mean_color']) - assert_true(tractogram1.data_per_point - is not tractogram2.data_per_point) - assert_true(tractogram1.data_per_point['colors'] - is not tractogram2.data_per_point['colors']) + assert_true(tractogram is not DATA['tractogram']) + assert_true(tractogram.streamlines + is not DATA['tractogram'].streamlines) + assert_true(tractogram.data_per_streamline + is not DATA['tractogram'].data_per_streamline) + assert_true(tractogram.data_per_point + is not DATA['tractogram'].data_per_point) - # Check the data are the equivalent. 
- assert_true(check_iteration(tractogram2)) - assert_equal(len(tractogram1), len(tractogram2)) - assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) - assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) - assert_arrays_equal(tractogram1.data_per_streamline['mean_curvature'], - tractogram2.data_per_streamline['mean_curvature']) - assert_arrays_equal(tractogram1.data_per_streamline['mean_color'], - tractogram2.data_per_streamline['mean_color']) - assert_arrays_equal(tractogram1.data_per_point['colors'], - tractogram2.data_per_point['colors']) + for key in tractogram.data_per_streamline: + assert_true(tractogram.data_per_streamline[key] + is not DATA['tractogram'].data_per_streamline[key]) + + for key in tractogram.data_per_point: + assert_true(tractogram.data_per_point[key] + is not DATA['tractogram'].data_per_point[key]) + + # Check the values of the data are the same. + assert_tractogram_equal(tractogram, DATA['tractogram']) def test_creating_invalid_tractogram(self): # Not enough data_per_point for all the points of all streamlines. @@ -356,14 +326,14 @@ def test_creating_invalid_tractogram(self): [(0, 1, 0)]*2, [(0, 0, 1)]*3] # Last streamlines has 5 points. - assert_raises(ValueError, Tractogram, self.streamlines, + assert_raises(ValueError, Tractogram, DATA['streamlines'], data_per_point={'scalars': scalars}) # Not enough data_per_streamline for all streamlines. properties = [np.array([1.11, 1.22], dtype="f4"), np.array([3.11, 3.22], dtype="f4")] - assert_raises(ValueError, Tractogram, self.streamlines, + assert_raises(ValueError, Tractogram, DATA['streamlines'], data_per_streamline={'properties': properties}) # Inconsistent dimension for a data_per_point. @@ -371,7 +341,7 @@ def test_creating_invalid_tractogram(self): [(0, 1)]*2, [(0, 0, 1)]*5] - assert_raises(ValueError, Tractogram, self.streamlines, + assert_raises(ValueError, Tractogram, DATA['streamlines'], data_per_point={'scalars': scalars}) # Inconsistent dimension for a data_per_streamline. @@ -379,7 +349,7 @@ def test_creating_invalid_tractogram(self): [2.11], [3.11, 3.22]] - assert_raises(ValueError, Tractogram, self.streamlines, + assert_raises(ValueError, Tractogram, DATA['streamlines'], data_per_streamline={'properties': properties}) # Too many dimension for a data_per_streamline. @@ -387,17 +357,11 @@ def test_creating_invalid_tractogram(self): np.array([[2.11], [2.22]], dtype="f4"), np.array([[3.11], [3.22]], dtype="f4")] - assert_raises(ValueError, Tractogram, self.streamlines, + assert_raises(ValueError, Tractogram, DATA['streamlines'], data_per_streamline={'properties': properties}) def test_tractogram_apply_affine(self): - # Create a tractogram with streamlines and other data. 
- tractogram = Tractogram( - self.streamlines, - data_per_streamline={'mean_curvature': self.mean_curvature, - 'mean_color': self.mean_color}, - data_per_point={'colors': self.colors}) - + tractogram = DATA['tractogram'].copy() affine = np.eye(4) scaling = np.array((1, 2, 3), dtype=float) affine[range(3), range(3)] = scaling @@ -406,9 +370,9 @@ def test_tractogram_apply_affine(self): transformed_tractogram = tractogram.apply_affine(affine, lazy=True) assert_true(type(transformed_tractogram) is LazyTractogram) assert_true(check_iteration(transformed_tractogram)) - assert_equal(len(transformed_tractogram), len(self.streamlines)) + assert_equal(len(transformed_tractogram), len(DATA['streamlines'])) for s1, s2 in zip(transformed_tractogram.streamlines, - self.streamlines): + DATA['streamlines']): assert_array_almost_equal(s1, s2*scaling) for s1, s2 in zip(transformed_tractogram.streamlines, @@ -422,9 +386,9 @@ def test_tractogram_apply_affine(self): transformed_tractogram = tractogram.apply_affine(affine) assert_true(transformed_tractogram is tractogram) assert_true(check_iteration(transformed_tractogram)) - assert_equal(len(transformed_tractogram), len(self.streamlines)) + assert_equal(len(transformed_tractogram), len(DATA['streamlines'])) for s1, s2 in zip(transformed_tractogram.streamlines, - self.streamlines): + DATA['streamlines']): assert_array_almost_equal(s1, s2*scaling) # Apply affine again and check the affine_to_rasmm. @@ -436,187 +400,121 @@ def test_tractogram_apply_affine(self): class TestLazyTractogram(unittest.TestCase): - def setUp(self): - self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), - np.arange(2*3, dtype="f4").reshape((2, 3)), - np.arange(5*3, dtype="f4").reshape((5, 3))] - - self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), - np.array([(0, 1, 0)]*2, dtype="f4"), - np.array([(0, 0, 1)]*5, dtype="f4")] - - self.mean_curvature = np.array([1.11, 2.11, 3.11], dtype="f4") - self.mean_color = np.array([[0, 1, 0], - [0, 0, 1], - [1, 0, 0]], dtype="f4") - - self.nb_streamlines = len(self.streamlines) - - self.colors_func = lambda: (x for x in self.colors) - self.mean_curvature_func = lambda: (x for x in self.mean_curvature) - self.mean_color_func = lambda: (x for x in self.mean_color) - def test_lazy_tractogram_creation(self): # To create tractogram from arrays use `Tractogram`. - assert_raises(TypeError, LazyTractogram, self.streamlines) + assert_raises(TypeError, LazyTractogram, DATA['streamlines']) # Streamlines and other data as generators - streamlines = (x for x in self.streamlines) - data_per_point = {"colors": (x for x in self.colors)} - data_per_streamline = {'mean_curv': (x for x in self.mean_curvature), - 'mean_color': (x for x in self.mean_color)} + streamlines = (x for x in DATA['streamlines']) + data_per_point = {"colors": (x for x in DATA['colors'])} + data_per_streamline = {'mean_torsion': (x for x in DATA['mean_torsion']), + 'mean_colors': (x for x in DATA['mean_colors'])} # Creating LazyTractogram with generators is not allowed as # generators get exhausted and are not reusable unlike coroutines. 
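Note: the restriction exercised by the assertions below exists because a
generator can be consumed only once, while a callable returning a fresh
generator can be re-called at will. A minimal sketch of the difference
(plain Python, independent of nibabel; the names are illustrative):

    import numpy as np

    streamlines = [np.arange(n * 3, dtype="f4").reshape((n, 3))
                   for n in (1, 2, 5)]

    gen = (s for s in streamlines)
    assert len(list(gen)) == 3   # First pass consumes the generator...
    assert list(gen) == []       # ...so a second pass yields nothing.

    make_gen = lambda: (s for s in streamlines)  # Re-callable, hence reusable.
    assert len(list(make_gen())) == len(list(make_gen())) == 3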
         assert_raises(TypeError, LazyTractogram, streamlines)
         assert_raises(TypeError, LazyTractogram,
                       data_per_streamline=data_per_streamline)
-        assert_raises(TypeError, LazyTractogram, self.streamlines,
+        assert_raises(TypeError, LazyTractogram, DATA['streamlines'],
                       data_per_point=data_per_point)
 
         # Empty `LazyTractogram`
         tractogram = LazyTractogram()
-        assert_true(check_iteration(tractogram))
-        assert_equal(len(tractogram), 0)
-        assert_arrays_equal(tractogram.streamlines, [])
-        assert_equal(tractogram.data_per_point, {})
-        assert_equal(tractogram.data_per_streamline, {})
+        check_tractogram(tractogram)
 
         # Create tractogram with streamlines and other data
-        streamlines = lambda: (x for x in self.streamlines)
-        data_per_point = {"colors": self.colors_func}
-        data_per_streamline = {'mean_curv': self.mean_curvature_func,
-                               'mean_color': self.mean_color_func}
-
-        tractogram = LazyTractogram(streamlines,
-                                    data_per_streamline=data_per_streamline,
-                                    data_per_point=data_per_point)
+        tractogram = LazyTractogram(DATA['streamlines_func'],
+                                    DATA['data_per_streamline_func'],
+                                    DATA['data_per_point_func'])
 
         assert_true(check_iteration(tractogram))
-        assert_equal(len(tractogram), self.nb_streamlines)
+        assert_equal(len(tractogram), len(DATA['streamlines']))
 
         # Coroutines get re-called and creates new iterators.
         for i in range(2):
-            assert_arrays_equal(tractogram.streamlines, self.streamlines)
-            assert_arrays_equal(tractogram.data_per_streamline['mean_curv'],
-                                self.mean_curvature)
-            assert_arrays_equal(tractogram.data_per_streamline['mean_color'],
-                                self.mean_color)
-            assert_arrays_equal(tractogram.data_per_point['colors'],
-                                self.colors)
+            assert_tractogram_equal(tractogram, DATA['tractogram'])
 
     def test_lazy_tractogram_create_from(self):
-        # Create `LazyTractogram` from a coroutine yielding nothing (i.e empty).
+        # Create an empty `LazyTractogram` yielding nothing.
         _empty_data_gen = lambda: iter([])
 
         tractogram = LazyTractogram.create_from(_empty_data_gen)
-        assert_true(check_iteration(tractogram))
-        assert_equal(len(tractogram), 0)
-        assert_arrays_equal(tractogram.streamlines, [])
-        assert_equal(tractogram.data_per_point, {})
-        assert_equal(tractogram.data_per_streamline, {})
+        check_tractogram(tractogram)
 
         # Create `LazyTractogram` from a coroutine yielding TractogramItem
+        data = [DATA['streamlines'], DATA['fa'], DATA['colors'],
+                DATA['mean_curvature'], DATA['mean_torsion'],
+                DATA['mean_colors']]
+
         def _data_gen():
-            for d in zip(self.streamlines, self.colors,
-                         self.mean_curvature, self.mean_color):
-                data_for_points = {'colors': d[1]}
-                data_for_streamline = {'mean_curv': d[2],
-                                       'mean_color': d[3]}
-                yield TractogramItem(d[0], data_for_streamline, data_for_points)
+            for d in zip(*data):
+                data_for_points = {'fa': d[1],
+                                   'colors': d[2]}
+                data_for_streamline = {'mean_curvature': d[3],
+                                       'mean_torsion': d[4],
+                                       'mean_colors': d[5]}
+                yield TractogramItem(d[0],
+                                     data_for_streamline,
+                                     data_for_points)
 
         tractogram = LazyTractogram.create_from(_data_gen)
-        assert_true(check_iteration(tractogram))
-        assert_equal(len(tractogram), self.nb_streamlines)
-        assert_arrays_equal(tractogram.streamlines, self.streamlines)
-        assert_arrays_equal(tractogram.data_per_streamline['mean_curv'],
-                            self.mean_curvature)
-        assert_arrays_equal(tractogram.data_per_streamline['mean_color'],
-                            self.mean_color)
-        assert_arrays_equal(tractogram.data_per_point['colors'],
-                            self.colors)
+        assert_tractogram_equal(tractogram, DATA['tractogram'])
 
         # Creating a LazyTractogram from a non-coroutine should raise an error.
assert_raises(TypeError, LazyTractogram.create_from, _data_gen()) def test_lazy_tractogram_getitem(self): - streamlines = lambda: (x for x in self.streamlines) - data_per_point = {"colors": self.colors_func} - data_per_streamline = {'mean_curv': self.mean_curvature_func, - 'mean_color': self.mean_color_func} - - # By default, `LazyTractogram` object does not support indexing. - tractogram = LazyTractogram(streamlines, - data_per_streamline=data_per_streamline, - data_per_point=data_per_point) - assert_raises(NotImplementedError, tractogram.__getitem__, 0) + assert_raises(NotImplementedError, + DATA['lazy_tractogram'].__getitem__, 0) def test_lazy_tractogram_len(self): - streamlines = lambda: (x for x in self.streamlines) - data_per_point = {"colors": self.colors_func} - data_per_streamline = {'mean_curv': self.mean_curvature_func, - 'mean_color': self.mean_color_func} - modules = [module_tractogram] # Modules for which to catch warnings. with clear_and_catch_warnings(record=True, modules=modules) as w: warnings.simplefilter("always") # Always trigger warnings. # Calling `len` will create new generators each time. - tractogram = LazyTractogram(streamlines, - data_per_streamline=data_per_streamline, - data_per_point=data_per_point) + tractogram = LazyTractogram(DATA['streamlines_func']) assert_true(tractogram._nb_streamlines is None) # This should produce a warning message. - assert_equal(len(tractogram), self.nb_streamlines) - assert_equal(tractogram._nb_streamlines, self.nb_streamlines) + assert_equal(len(tractogram), len(DATA['streamlines'])) + assert_equal(tractogram._nb_streamlines, len(DATA['streamlines'])) assert_equal(len(w), 1) - tractogram = LazyTractogram(streamlines, - data_per_streamline=data_per_streamline, - data_per_point=data_per_point) + tractogram = LazyTractogram(DATA['streamlines_func']) # New instances should still produce a warning message. - assert_equal(len(tractogram), self.nb_streamlines) + assert_equal(len(tractogram), len(DATA['streamlines'])) assert_equal(len(w), 2) assert_true(issubclass(w[-1].category, Warning)) # Calling again 'len' again should *not* produce a warning. - assert_equal(len(tractogram), self.nb_streamlines) + assert_equal(len(tractogram), len(DATA['streamlines'])) assert_equal(len(w), 2) with clear_and_catch_warnings(record=True, modules=modules) as w: # Once we iterated through the tractogram, we know the length. - tractogram = LazyTractogram(streamlines, - data_per_streamline=data_per_streamline, - data_per_point=data_per_point) + tractogram = LazyTractogram(DATA['streamlines_func']) assert_true(tractogram._nb_streamlines is None) check_iteration(tractogram) # Force iteration through tractogram. - assert_equal(tractogram._nb_streamlines, len(self.streamlines)) + assert_equal(tractogram._nb_streamlines, len(DATA['streamlines'])) # This should *not* produce a warning. 
- assert_equal(len(tractogram), len(self.streamlines)) + assert_equal(len(tractogram), len(DATA['streamlines'])) assert_equal(len(w), 0) def test_lazy_tractogram_apply_affine(self): - streamlines = lambda: (x for x in self.streamlines) - data_per_point = {"colors": self.colors_func} - data_per_streamline = {'mean_curv': self.mean_curvature_func, - 'mean_color': self.mean_color_func} - affine = np.eye(4) scaling = np.array((1, 2, 3), dtype=float) affine[range(3), range(3)] = scaling - tractogram = LazyTractogram(streamlines, - data_per_streamline=data_per_streamline, - data_per_point=data_per_point) + tractogram = DATA['lazy_tractogram'].copy() tractogram.apply_affine(affine) assert_true(check_iteration(tractogram)) - assert_equal(len(tractogram), len(self.streamlines)) - for s1, s2 in zip(tractogram.streamlines, self.streamlines): + assert_equal(len(tractogram), len(DATA['streamlines'])) + for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']): assert_array_almost_equal(s1, s2*scaling) # Apply affine again and check the affine_to_rasmm. @@ -626,55 +524,37 @@ def test_lazy_tractogram_apply_affine(self): np.linalg.inv(affine)))) def test_lazy_tractogram_copy(self): - # Create tractogram with streamlines and other data - streamlines = lambda: (x for x in self.streamlines) - data_per_point = {"colors": self.colors_func} - data_per_streamline = {'mean_curv': self.mean_curvature_func, - 'mean_color': self.mean_color_func} - - tractogram1 = LazyTractogram(streamlines, - data_per_streamline=data_per_streamline, - data_per_point=data_per_point) - assert_true(check_iteration(tractogram1)) # Implicitly set _nb_streamlines. - - # Create a copy of the tractogram. - tractogram2 = tractogram1.copy() + # Create a copy of the lazy tractogram. + tractogram = DATA['lazy_tractogram'].copy() # Check we copied the data and not simply created new references. - assert_true(tractogram1 is not tractogram2) + assert_true(tractogram is not DATA['lazy_tractogram']) # When copying LazyTractogram, coroutines generating streamlines should # be the same. - assert_true(tractogram1._streamlines is tractogram2._streamlines) + assert_true(tractogram._streamlines + is DATA['lazy_tractogram']._streamlines) # Copying LazyTractogram, creates new internal LazyDict objects, # but coroutines contained in it should be the same. - assert_true(tractogram1._data_per_streamline - is not tractogram2._data_per_streamline) - assert_true(tractogram1.data_per_streamline.store['mean_curv'] - is tractogram2.data_per_streamline.store['mean_curv']) - assert_true(tractogram1.data_per_streamline.store['mean_color'] - is tractogram2.data_per_streamline.store['mean_color']) - assert_true(tractogram1._data_per_point - is not tractogram2._data_per_point) - assert_true(tractogram1.data_per_point.store['colors'] - is tractogram2.data_per_point.store['colors']) + assert_true(tractogram._data_per_streamline + is not DATA['lazy_tractogram']._data_per_streamline) + assert_true(tractogram._data_per_point + is not DATA['lazy_tractogram']._data_per_point) + + for key in tractogram.data_per_streamline: + assert_true(tractogram.data_per_streamline.store[key] + is DATA['lazy_tractogram'].data_per_streamline.store[key]) + + for key in tractogram.data_per_point: + assert_true(tractogram.data_per_point.store[key] + is DATA['lazy_tractogram'].data_per_point.store[key]) # The affine should be a copy. 
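Note: the `affine_to_rasmm` assertions above boil down to plain affine
composition: each call to `apply_affine(affine)` must fold `inv(affine)` into
the transform that maps the streamlines back to RAS+mm. A sketch of that
bookkeeping with only numpy (the scaling values mirror the tests; nothing
here is nibabel API):

    import numpy as np

    affine = np.eye(4)
    affine[range(3), range(3)] = (1, 2, 3)   # Same scaling as in the tests.

    affine_to_rasmm = np.eye(4)              # Identity before any transform.
    affine_to_rasmm = np.dot(affine_to_rasmm, np.linalg.inv(affine))

    pts = np.arange(15, dtype=float).reshape((5, 3))
    moved = pts * (1, 2, 3)                  # What the scaling affine does.
    back = np.dot(moved, affine_to_rasmm[:3, :3].T) + affine_to_rasmm[:3, 3]
    assert np.allclose(back, pts)            # Round trip recovers the input.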
- assert_true(tractogram1._affine_to_apply - is not tractogram2._affine_to_apply) - assert_array_equal(tractogram1._affine_to_apply, - tractogram2._affine_to_apply) + assert_true(tractogram._affine_to_apply + is not DATA['lazy_tractogram']._affine_to_apply) + assert_array_equal(tractogram._affine_to_apply, + DATA['lazy_tractogram']._affine_to_apply) # Check the data are the equivalent. - assert_equal(tractogram1._nb_streamlines, tractogram2._nb_streamlines) - assert_true(check_iteration(tractogram2)) - assert_equal(len(tractogram1), len(tractogram2)) - assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) - assert_arrays_equal(tractogram1.streamlines, tractogram2.streamlines) - assert_arrays_equal(tractogram1.data_per_streamline['mean_curv'], - tractogram2.data_per_streamline['mean_curv']) - assert_arrays_equal(tractogram1.data_per_streamline['mean_color'], - tractogram2.data_per_streamline['mean_color']) - assert_arrays_equal(tractogram1.data_per_point['colors'], - tractogram2.data_per_point['colors']) + assert_tractogram_equal(tractogram, DATA['tractogram']) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 1b23843c5b..fc86879306 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -8,6 +8,136 @@ from .array_sequence import ArraySequence +class DataDict(collections.MutableMapping): + """ Dictionary that makes sure data are 2D array. + + This container behaves like a standard dictionary but it makes sure its + elements are ndarrays. In addition, it makes sure the amount of data + contained in those ndarrays matches the number of streamlines of the + :class:`Tractogram` object provided at the instantiation of this + dictionary. + """ + def __init__(self, tractogram, *args, **kwargs): + self.tractogram = tractogram + self.store = dict() + + # Use update to set the keys. + if len(args) == 1: + if isinstance(args[0], DataDict): + self.update(dict(args[0].store.items())) + elif args[0] is None: + return + else: + self.update(dict(*args, **kwargs)) + else: + self.update(dict(*args, **kwargs)) + + def __getitem__(self, key): + try: + return self.store[key] + except KeyError: + pass # Maybe it is an integer. + except TypeError: + pass # Maybe it is an object for advanced indexing. + + # Try to interpret key as an index/slice in which case we + # perform (advanced) indexing on every element of the dictionnary. + try: + idx = key + new_dict = type(self)(None) + for k, v in self.items(): + new_dict[k] = v[idx] + + return new_dict + except TypeError: + pass + + # That means key was not an index/slice after all. + return self.store[key] # Will raise the proper error. + + def __delitem__(self, key): + del self.store[key] + + def __iter__(self): + return iter(self.store) + + def __len__(self): + return len(self.store) + + +class DataPerStreamlineDict(DataDict): + """ Dictionary that makes sure data are 2D array. + + This container behaves like a standard dictionary but it makes sure its + elements are ndarrays. In addition, it makes sure the amount of data + contained in those ndarrays matches the number of streamlines of the + :class:`Tractogram` object provided at the instantiation of this + dictionary. 
+ """ + def __setitem__(self, key, value): + value = np.asarray(value) + + if value.ndim == 1 and value.dtype != object: + # Reshape without copy + value.shape = ((len(value), 1)) + + if value.ndim != 2: + raise ValueError("data_per_streamline must be a 2D array.") + + # We make sure there is the right amount of values + # (i.e. same as the number of streamlines in the tractogram). + if self.tractogram is not None and len(value) != len(self.tractogram): + msg = ("The number of values ({0}) should match the number of" + " streamlines ({1}).") + raise ValueError(msg.format(len(value), len(self.tractogram))) + + self.store[key] = value + + +class DataPerPointDict(DataDict): + """ Dictionary making sure data are :class:`ArraySequence` objects. + + This container behaves like a standard dictionary but it makes sure its + elements are :class:`ArraySequence` objects. In addition, it makes sure + the amount of data contained in those :class:`ArraySequence` objects + matches the the number of points of the :class:`Tractogram` object + provided at the instantiation of this dictionary. + """ + + def __setitem__(self, key, value): + value = ArraySequence(value) + + # We make sure we have the right amount of values (i.e. same as + # the total number of points of all streamlines in the tractogram). + if (self.tractogram is not None and + len(value._data) != len(self.tractogram.streamlines._data)): + msg = ("The number of values ({0}) should match the total" + " number of points of all streamlines ({1}).") + nb_streamlines_points = self.tractogram.streamlines._data + raise ValueError(msg.format(len(value._data), + len(nb_streamlines_points))) + + self.store[key] = value + + +class LazyDict(DataDict): + """ Dictionary of coroutines with lazy evaluation. + + This container behaves like an dictionary but it makes sure its elements + are callable objects and assumed to be coroutines yielding values. When + getting the element associated to a given key, the element (i.e. a + coroutine) is first called before being returned. + """ + def __getitem__(self, key): + return self.store[key]() + + def __setitem__(self, key, value): + if value is not None and not callable(value): + raise TypeError("`value` must be a coroutine or None.") + + self.store[key] = value + + class TractogramItem(object): """ Class containing information about one streamline. @@ -48,73 +178,6 @@ class Tractogram(object): `data_per_streamline` and `data_per_point`. """ - class DataDict(collections.MutableMapping): - def __init__(self, tractogram, *args, **kwargs): - self.tractogram = tractogram - self.store = dict() - - # Use update to set the keys. - if len(args) == 1: - if isinstance(args[0], Tractogram.DataDict): - self.update(dict(args[0].store.items())) - elif args[0] is None: - return - else: - self.update(dict(*args, **kwargs)) - else: - self.update(dict(*args, **kwargs)) - - def __getitem__(self, key): - return self.store[key] - - def __delitem__(self, key): - del self.store[key] - - def __iter__(self): - return iter(self.store) - - def __len__(self): - return len(self.store) - - class DataPerStreamlineDict(DataDict): - """ Dictionary that makes sure data are 2D array. """ - - def __setitem__(self, key, value): - value = np.asarray(value) - - if value.ndim == 1 and value.dtype != object: - # Reshape without copy - value.shape = ((len(value), 1)) - - if value.ndim != 2: - raise ValueError("data_per_streamline must be a 2D array.") - - # We make sure there is the right amount of values - # (i.e. 
same as the number of streamlines in the tractogram). - if len(value) != len(self.tractogram): - msg = ("The number of values ({0}) should match the number of" - " streamlines ({1}).") - raise ValueError(msg.format(len(value), len(self.tractogram))) - - self.store[key] = value - - class DataPerPointDict(DataDict): - """ Dictionary making sure data are :class:`ArraySequence` objects. """ - - def __setitem__(self, key, value): - value = ArraySequence(value) - - # We make sure we have the right amount of values (i.e. same as - # the total number of points of all streamlines in the tractogram). - if len(value._data) != len(self.tractogram.streamlines._data): - msg = ("The number of values ({0}) should match the total" - " number of points of all streamlines ({1}).") - nb_streamlines_points = self.tractogram.streamlines._data - raise ValueError(msg.format(len(value._data), - len(nb_streamlines_points))) - - self.store[key] = value - def __init__(self, streamlines=None, data_per_streamline=None, data_per_point=None, @@ -159,8 +222,7 @@ def data_per_streamline(self): @data_per_streamline.setter def data_per_streamline(self, value): - self._data_per_streamline = Tractogram.DataPerStreamlineDict(self, - value) + self._data_per_streamline = DataPerStreamlineDict(self, value) @property def data_per_point(self): @@ -168,7 +230,7 @@ def data_per_point(self): @data_per_point.setter def data_per_point(self, value): - self._data_per_point = Tractogram.DataPerPointDict(self, value) + self._data_per_point = DataPerPointDict(self, value) def get_affine_to_rasmm(self): """ Returns the affine bringing this tractogram to RAS+mm. """ @@ -257,37 +319,6 @@ class LazyTractogram(Tractogram): If provided, `scalars` and `properties` must yield the same number of values as `streamlines`. """ - - class LazyDict(collections.MutableMapping): - """ Internal dictionary with lazy evaluations. """ - - def __init__(self, *args, **kwargs): - self.store = dict() - - # Use update to set keys. 
- if len(args) == 1 and isinstance(args[0], LazyTractogram.LazyDict): - self.update(dict(args[0].store.items())) - else: - self.update(dict(*args, **kwargs)) - - def __getitem__(self, key): - return self.store[key]() - - def __setitem__(self, key, value): - if value is not None and not callable(value): - raise TypeError("`value` must be a coroutine or None.") - - self.store[key] = value - - def __delitem__(self, key): - del self.store[key] - - def __iter__(self): - return iter(self.store) - - def __len__(self): - return len(self.store) - def __init__(self, streamlines=None, data_per_streamline=None, data_per_point=None): @@ -426,7 +457,7 @@ def data_per_streamline(self, value): if value is None: value = {} - self._data_per_streamline = LazyTractogram.LazyDict(value) + self._data_per_streamline = LazyDict(self, value) @property def data_per_point(self): @@ -437,7 +468,7 @@ def data_per_point(self, value): if value is None: value = {} - self._data_per_point = LazyTractogram.LazyDict(value) + self._data_per_point = LazyDict(self, value) @property def data(self): From 36bff9fd164c540942f9acef4c3cf31289803055 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 19 Feb 2016 10:14:04 -0500 Subject: [PATCH 084/135] Rebased master --- nibabel/testing/__init__.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py index e5332014de..fa3454d473 100644 --- a/nibabel/testing/__init__.py +++ b/nibabel/testing/__init__.py @@ -14,7 +14,6 @@ import sys import warnings from os.path import dirname, abspath, join as pjoin -from nibabel.externals.six.moves import zip_longest import numpy as np from numpy.testing import assert_array_equal @@ -27,6 +26,8 @@ except ImportError: pass +from nibabel.externals.six.moves import zip_longest + # set path to example data data_path = abspath(pjoin(dirname(__file__), '..', 'tests', 'data')) @@ -64,6 +65,22 @@ def assert_allclose_safely(a, b, match_nans=True, rtol=1e-5, atol=1e-8): assert_true(np.allclose(a, b, rtol=rtol, atol=atol)) +def check_iteration(iterable): + """ Checks that an object can be iterated through without errors. 
""" + try: + for _ in iterable: + pass + except: + return False + + return True + + +def assert_arrays_equal(arrays1, arrays2): + for arr1, arr2 in zip_longest(arrays1, arrays2, fillvalue=None): + assert_array_equal(arr1, arr2) + + def assert_re_in(regex, c, flags=0): """Assert that container (list, str, etc) contains entry matching the regex """ @@ -75,7 +92,6 @@ def assert_re_in(regex, c, flags=0): raise AssertionError("Not a single entry matched %r in %r" % (regex, c)) - def get_fresh_mod(mod_name=__name__): # Get this module, with warning registry empty my_mod = sys.modules[mod_name] @@ -179,7 +195,17 @@ class suppress_warnings(error_warnings): class catch_warn_reset(clear_and_catch_warnings): + def __init__(self, *args, **kwargs): warnings.warn('catch_warn_reset is deprecated and will be removed in ' 'nibabel v3.0; use nibabel.testing.clear_and_catch_warnings.', FutureWarning) + + +EXTRA_SET = os.environ.get('NIPY_EXTRA_TESTS', '').split(',') + + +def runif_extra_has(test_str): + """Decorator checks to see if NIPY_EXTRA_TESTS env var contains test_str""" + return skipif(test_str not in EXTRA_SET, + "Skip {0} tests.".format(test_str)) From 5faf62ecad30888bae376744e06483ef1ce77870 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 19 Feb 2016 23:30:06 -0500 Subject: [PATCH 085/135] Support numpy 1.5 and make code PEP8 compliant. --- nibabel/streamlines/array_sequence.py | 18 +- .../streamlines/tests/test_array_sequence.py | 12 +- nibabel/streamlines/tractogram.py | 7 +- nibabel/streamlines/trk.py | 173 +++++++++++------- 4 files changed, 126 insertions(+), 84 deletions(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index 32521eb20c..f51d936fdf 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -22,7 +22,7 @@ class ArraySequence(object): same for every ndarray. """ - BUFFER_SIZE = 87382*4 # About 4 Mb if item shape is 3 (e.g. 3D points). + BUFFER_SIZE = 87382 * 4 # About 4 Mb if item shape is 3 (e.g. 3D points). 
def __init__(self, iterable=None): """ @@ -72,7 +72,7 @@ def __init__(self, iterable=None): offsets.append(offset) lengths.append(len(e)) - self._data[offset:offset+len(e)] = e + self._data[offset:offset + len(e)] = e offset += len(e) self._offsets = np.asarray(offsets) @@ -148,14 +148,14 @@ def extend(self, elements): next_offset = self._data.shape[0] if is_array_sequence(elements): - self._data.resize((self._data.shape[0]+sum(elements._lengths), + self._data.resize((self._data.shape[0] + sum(elements._lengths), self._data.shape[1])) offsets = [] for offset, length in zip(elements._offsets, elements._lengths): offsets.append(next_offset) - chunk = elements._data[offset:offset+length] - self._data[next_offset:next_offset+length] = chunk + chunk = elements._data[offset:offset + length] + self._data[next_offset:next_offset + length] = chunk next_offset += length self._lengths = np.r_[self._lengths, elements._lengths] @@ -182,8 +182,8 @@ def copy(self): offsets = [] for offset, length in zip(self._offsets, self._lengths): offsets.append(next_offset) - chunk = self._data[offset:offset+length] - seq._data[next_offset:next_offset+length] = chunk + chunk = self._data[offset:offset + length] + seq._data[next_offset:next_offset + length] = chunk next_offset += length seq._offsets = np.asarray(offsets) @@ -212,7 +212,7 @@ def __getitem__(self, idx): """ if isinstance(idx, (int, np.integer)): start = self._offsets[idx] - return self._data[start:start+self._lengths[idx]] + return self._data[start:start + self._lengths[idx]] elif isinstance(idx, (slice, list)): seq = ArraySequence() @@ -241,7 +241,7 @@ def __iter__(self): " len(self._lengths) != len(self._offsets)") for offset, lengths in zip(self._offsets, self._lengths): - yield self._data[offset: offset+lengths] + yield self._data[offset: offset + lengths] def __len__(self): return len(self._offsets) diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index 245d2afd53..4cf41e7e05 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -18,12 +18,12 @@ def setup(): global SEQ_DATA rng = np.random.RandomState(42) SEQ_DATA['rng'] = rng - SEQ_DATA['data'] = generate_data(nb_arrays=10, common_shape=(3,), rng=rng) + SEQ_DATA['data'] = generate_data(nb_arrays=5, common_shape=(3,), rng=rng) SEQ_DATA['seq'] = ArraySequence(SEQ_DATA['data']) def generate_data(nb_arrays, common_shape, rng): - data = [rng.rand(*(rng.randint(10, 50),) + common_shape) + data = [rng.rand(*(rng.randint(3, 20),) + common_shape) for _ in range(nb_arrays)] return data @@ -79,7 +79,7 @@ def test_creating_arraysequence_from_list(self): for ndim in range(0, N+1): common_shape = tuple([SEQ_DATA['rng'].randint(1, 10) for _ in range(ndim-1)]) - data = generate_data(nb_arrays=10, common_shape=common_shape, + data = generate_data(nb_arrays=5, common_shape=common_shape, rng=SEQ_DATA['rng']) check_arr_seq(ArraySequence(data), data) @@ -213,7 +213,7 @@ def test_arraysequence_getitem(self): SEQ_DATA['rng'].shuffle(indices) seq_view = SEQ_DATA['seq'][indices] check_arr_seq_view(seq_view, SEQ_DATA['seq']) - check_arr_seq(seq_view, np.asarray(SEQ_DATA['data'])[indices]) + check_arr_seq(seq_view, [SEQ_DATA['data'][i] for i in indices]) # Get slice (this will create a view). 
seq_view = SEQ_DATA['seq'][::2] @@ -224,7 +224,9 @@ def test_arraysequence_getitem(self): selection = np.array([False, True, True, False, True]) seq_view = SEQ_DATA['seq'][selection] check_arr_seq_view(seq_view, SEQ_DATA['seq']) - check_arr_seq(seq_view, np.asarray(SEQ_DATA['data'])[selection]) + check_arr_seq(seq_view, + [SEQ_DATA['data'][i] + for i, keep in enumerate(selection) if keep]) # Test invalid indexing assert_raises(TypeError, SEQ_DATA['seq'].__getitem__, 'abc') diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index fc86879306..e620682296 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -295,9 +295,10 @@ def apply_affine(self, affine, lazy=False): return self BUFFER_SIZE = 10000000 # About 128 Mb since pts shape is 3. - for i in range(0, len(self.streamlines._data), BUFFER_SIZE): - pts = self.streamlines._data[i:i+BUFFER_SIZE] - self.streamlines._data[i:i+BUFFER_SIZE] = apply_affine(affine, pts) + for start in range(0, len(self.streamlines._data), BUFFER_SIZE): + end = start + BUFFER_SIZE + pts = self.streamlines._data[start:end] + self.streamlines._data[start:end] = apply_affine(affine, pts) # Update the affine that brings back the streamlines to RASmm. self._affine_to_rasmm = np.dot(self._affine_to_rasmm, diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index b9ea297b53..bd501c9114 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -35,7 +35,8 @@ (Field.NB_SCALARS_PER_POINT, 'h'), ('scalar_name', 'S20', MAX_NB_NAMED_SCALARS_PER_POINT), (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), - ('property_name', 'S20', MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE), + ('property_name', 'S20', + MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE), ('reserved', 'S508'), (Field.VOXEL_ORDER, 'S4'), ('pad2', 'S4'), @@ -63,7 +64,8 @@ (Field.NB_SCALARS_PER_POINT, 'h'), ('scalar_name', 'S20', MAX_NB_NAMED_SCALARS_PER_POINT), (Field.NB_PROPERTIES_PER_STREAMLINE, 'h'), - ('property_name', 'S20', MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE), + ('property_name', 'S20', + MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE), (Field.VOXEL_TO_RASMM, 'f4', (4, 4)), # New in version 2. ('reserved', 'S444'), (Field.VOXEL_ORDER, 'S4'), @@ -171,9 +173,11 @@ def __iter__(self): with Opener(self.fileobj) as f: start_position = f.tell() - nb_pts_and_scalars = int(3 + self.header[Field.NB_SCALARS_PER_POINT]) + nb_pts_and_scalars = int(3 + + self.header[Field.NB_SCALARS_PER_POINT]) pts_and_scalars_size = int(nb_pts_and_scalars * f4_dtype.itemsize) - properties_size = int(self.header[Field.NB_PROPERTIES_PER_STREAMLINE] * f4_dtype.itemsize) + nb_properties = self.header[Field.NB_PROPERTIES_PER_STREAMLINE] + properties_size = int(nb_properties * f4_dtype.itemsize) # Set the file position at the beginning of the data. 
f.seek(self.offset_data, os.SEEK_SET) @@ -197,17 +201,19 @@ def __iter__(self): nb_pts = struct.unpack(nb_pts_dtype, nb_pts_str)[0] # Read streamline's data - points_and_scalars = np.ndarray(shape=(nb_pts, nb_pts_and_scalars), - dtype=f4_dtype, - buffer=f.read(nb_pts * pts_and_scalars_size)) + points_and_scalars = np.ndarray( + shape=(nb_pts, nb_pts_and_scalars), + dtype=f4_dtype, + buffer=f.read(nb_pts * pts_and_scalars_size)) points = points_and_scalars[:, :3] scalars = points_and_scalars[:, 3:] # Read properties - properties = np.ndarray(shape=(self.header[Field.NB_PROPERTIES_PER_STREAMLINE],), - dtype=f4_dtype, - buffer=f.read(properties_size)) + properties = np.ndarray( + shape=(self.header[Field.NB_PROPERTIES_PER_STREAMLINE],), + dtype=f4_dtype, + buffer=f.read(properties_size)) yield points, scalars, properties i += 1 @@ -283,13 +289,15 @@ def write(self, tractogram): # tractogram. data_for_streamline = first_item.data_for_streamline if len(data_for_streamline) > MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE: - raise ValueError(("Can only store {0} named data_per_streamline" - " (also known as 'properties' in the TRK" - " format).").format(MAX_NB_NAMED_SCALARS_PER_POINT)) + msg = ("Can only store {0} named data_per_streamline (also known" + " as 'properties' in the TRK format)." + ).format(MAX_NB_NAMED_SCALARS_PER_POINT) + raise ValueError(msg) data_for_streamline_keys = sorted(data_for_streamline.keys()) - self.header['property_name'] = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, - dtype='S20') + self.header['property_name'] = np.zeros( + MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, + dtype='S20') for i, k in enumerate(data_for_streamline_keys): nb_values = data_for_streamline[k].shape[0] @@ -305,8 +313,8 @@ def write(self, tractogram): if nb_values > 1: # Use the last to bytes of the name to store the nb of values # associated to this data_for_streamline. - property_name = asbytes(k[:18].ljust(18, '\x00')) + b'\x00' \ - + np.array(nb_values, dtype=np.int8).tostring() + property_name = (asbytes(k[:18].ljust(18, '\x00')) + b'\x00' + + np.array(nb_values, dtype=np.int8).tostring()) self.header['property_name'][i] = property_name @@ -336,8 +344,8 @@ def write(self, tractogram): if nb_values > 1: # Use the last to bytes of the name to store the nb of values # associated to this data_for_streamline. 
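Note: the encoding referred to above (used for both 'property_name' and
'scalar_name') packs a value count into a 20-byte name field: at most 18 name
characters, a null byte, then an int8 count. A sketch under that assumption;
`encode_name` is illustrative, and modern numpy spellings
(`tobytes`/`frombuffer`) stand in for the `tostring`/`fromstring` calls used
in the patch:

    import numpy as np

    def encode_name(name, nb_values):
        if nb_values == 1:
            return name[:20].encode()   # The 'S20' field pads with nulls.
        return (name[:18].ljust(18, '\x00').encode() + b'\x00' +
                np.int8(nb_values).tobytes())

    encoded = encode_name('mean_colors', 3)
    assert len(encoded) == 20
    assert encoded.split(b'\x00')[0] == b'mean_colors'
    assert np.frombuffer(encoded[-1:], np.int8)[0] == 3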
- scalar_name = asbytes(k[:18].ljust(18, '\x00')) + b'\x00' \ - + np.array(nb_values, dtype=np.int8).tostring() + scalar_name = (asbytes(k[:18].ljust(18, '\x00')) + b'\x00' + + np.array(nb_values, dtype=np.int8).tostring()) self.header['scalar_name'][i] = scalar_name @@ -379,14 +387,21 @@ def write(self, tractogram): affine = affine.astype(np.float32) for t in tractogram: - if any((len(d) != len(t.streamline) for d in t.data_for_points.values())): + if any((len(d) != len(t.streamline) + for d in t.data_for_points.values())): raise DataError("Missing scalars for some points!") - points = apply_affine(affine, np.asarray(t.streamline, dtype=f4_dtype)) - scalars = [np.asarray(t.data_for_points[k], dtype=f4_dtype) for k in data_for_points_keys] - scalars = np.concatenate([np.ndarray((len(points), 0), dtype=f4_dtype)] + scalars, axis=1) - properties = [np.asarray(t.data_for_streamline[k], dtype=f4_dtype) for k in data_for_streamline_keys] - properties = np.concatenate([np.array([], dtype=f4_dtype)] + properties) + points = apply_affine(affine, + np.asarray(t.streamline, dtype=f4_dtype)) + scalars = [np.asarray(t.data_for_points[k], dtype=f4_dtype) + for k in data_for_points_keys] + scalars = np.concatenate([np.ndarray((len(points), 0), + dtype=f4_dtype) + ] + scalars, axis=1) + properties = [np.asarray(t.data_for_streamline[k], dtype=f4_dtype) + for k in data_for_streamline_keys] + properties = np.concatenate([np.array([], dtype=f4_dtype) + ] + properties) data = struct.pack(i4_dtype.str[:-1], len(points)) data += np.concatenate([points, scalars], axis=1).tostring() @@ -437,7 +452,7 @@ class TrkFile(TractogramFile): # Contants MAGIC_NUMBER = b"TRACK" HEADER_SIZE = 1000 - READ_BUFFER_SIZE = 87382*4 # About 4 Mb if there is no scalars nor properties. + READ_BUFFER_SIZE = 87382 * 4 # About 4 Mb when no scalars nor properties. def __init__(self, tractogram, header=None): """ @@ -468,12 +483,12 @@ def get_magic_number(cls): @classmethod def support_data_per_point(cls): - """ Tells if this tractogram format supports saving data per point. """ + """ Tells if this format supports saving data per point. """ return True @classmethod def support_data_per_streamline(cls): - """ Tells if this tractogram format supports saving data per streamline. """ + """ Tells if this format supports saving data per streamline. """ return True @classmethod @@ -524,9 +539,12 @@ def _create_arraysequence_from_generator(cls, gen): scals_shape = scals.shape props_shape = props.shape - streamlines._data = np.empty((cls.READ_BUFFER_SIZE, pts.shape[1]), dtype=pts.dtype) - scalars._data = np.empty((cls.READ_BUFFER_SIZE, scals.shape[1]), dtype=scals.dtype) - properties = np.empty((cls.READ_BUFFER_SIZE, props.shape[0]), dtype=props.dtype) + streamlines._data = np.empty((cls.READ_BUFFER_SIZE, pts.shape[1]), + dtype=pts.dtype) + scalars._data = np.empty((cls.READ_BUFFER_SIZE, scals.shape[1]), + dtype=scals.dtype) + properties = np.empty((cls.READ_BUFFER_SIZE, props.shape[0]), + dtype=props.dtype) offset = 0 offsets = [] @@ -547,18 +565,23 @@ def _create_arraysequence_from_generator(cls, gen): end = offset + len(pts) if end >= len(streamlines._data): # Resize is needed (at least `len(pts)` items will be added). 
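Note: the resize dance above implements amortized growth: the backing buffers
are enlarged one READ_BUFFER_SIZE chunk at a time so that loading N points
does not trigger N reallocations, and the unused tail is trimmed afterwards.
A stripped-down sketch of the same strategy (the fake reader is illustrative;
`refcheck=False` keeps the in-place resize permissive in this standalone
snippet):

    import numpy as np

    READ_BUFFER_SIZE = 87382 * 4

    def read_chunks():  # Stand-in for iterating over a TrkReader.
        for _ in range(3):
            yield np.zeros((10, 3), dtype="f4")

    data = np.empty((READ_BUFFER_SIZE, 3), dtype="f4")
    offset = 0
    for pts in read_chunks():
        end = offset + len(pts)
        if end > len(data):  # Grow by a whole chunk, not just len(pts).
            data.resize((len(data) + len(pts) + READ_BUFFER_SIZE, 3),
                        refcheck=False)
        data[offset:end] = pts
        offset = end

    data.resize((offset, 3), refcheck=False)  # Trim the unused tail.
    assert data.shape == (30, 3)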
- streamlines._data.resize((len(streamlines._data) + len(pts)+cls.READ_BUFFER_SIZE, pts.shape[1])) - scalars._data.resize((len(scalars._data) + len(scals)+cls.READ_BUFFER_SIZE, scals.shape[1])) + streamlines._data.resize((len(streamlines._data) + len(pts) + + cls.READ_BUFFER_SIZE, + pts.shape[1])) + scalars._data.resize((len(scalars._data) + len(scals) + + cls.READ_BUFFER_SIZE, + scals.shape[1])) offsets.append(offset) lengths.append(len(pts)) - streamlines._data[offset:offset+len(pts)] = pts - scalars._data[offset:offset+len(scals)] = scals + streamlines._data[offset:offset + len(pts)] = pts + scalars._data[offset:offset + len(scals)] = scals offset += len(pts) if i >= len(properties): - properties.resize((len(properties) + cls.READ_BUFFER_SIZE, props.shape[0])) + properties.resize((len(properties) + cls.READ_BUFFER_SIZE, + props.shape[0])) properties[i] = props @@ -580,9 +603,9 @@ def _create_arraysequence_from_generator(cls, gen): if props_shape[0] == 0: # Because resizing an empty ndarray creates memory! - properties = np.empty((i+1, props.shape[0])) + properties = np.empty((i + 1, props.shape[0])) else: - properties.resize((i+1, props.shape[0])) + properties.resize((i + 1, props.shape[0])) return streamlines, scalars, properties @@ -613,6 +636,7 @@ def load(cls, fileobj, lazy_load=False): voxel. """ trk_reader = TrkReader(fileobj) + hdr = trk_reader.header # TRK's streamlines are in 'voxelmm' space, we will compute the # affine matrix that will bring them back to RAS+ and mm space. @@ -622,7 +646,7 @@ def load(cls, fileobj, lazy_load=False): # in the voxel space. # voxelmm -> voxel scale = np.eye(4) - scale[range(3), range(3)] /= trk_reader.header[Field.VOXEL_SIZES] + scale[range(3), range(3)] /= hdr[Field.VOXEL_SIZES] affine = np.dot(scale, affine) # TrackVis considers coordinate (0,0,0) to be the corner of the voxel @@ -635,23 +659,24 @@ def load(cls, fileobj, lazy_load=False): # If the voxel order implied by the affine does not match the voxel # order in the TRK header, change the orientation. # voxel (header) -> voxel (affine) - header_ornt = asstr(trk_reader.header[Field.VOXEL_ORDER]) - affine_ornt = "".join(aff2axcodes(trk_reader.header[Field.VOXEL_TO_RASMM])) + header_ornt = asstr(hdr[Field.VOXEL_ORDER]) + affine_ornt = "".join(aff2axcodes(hdr[Field.VOXEL_TO_RASMM])) header_ornt = axcodes2ornt(header_ornt) affine_ornt = axcodes2ornt(affine_ornt) ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt) - M = nib.orientations.inv_ornt_aff(ornt, trk_reader.header[Field.DIMENSIONS]) + M = nib.orientations.inv_ornt_aff(ornt, + hdr[Field.DIMENSIONS]) affine = np.dot(M, affine) # Applied the affine found in the TRK header. 
# voxel -> rasmm - affine = np.dot(trk_reader.header[Field.VOXEL_TO_RASMM], affine) + affine = np.dot(hdr[Field.VOXEL_TO_RASMM], affine) # Find scalars and properties name data_per_point_slice = {} - if trk_reader.header[Field.NB_SCALARS_PER_POINT] > 0: + if hdr[Field.NB_SCALARS_PER_POINT] > 0: cpt = 0 - for scalar_name in trk_reader.header['scalar_name']: + for scalar_name in hdr['scalar_name']: scalar_name = asstr(scalar_name) if len(scalar_name) == 0: continue @@ -663,16 +688,18 @@ def load(cls, fileobj, lazy_load=False): nb_scalars = int(np.fromstring(scalar_name[-1], np.int8)) scalar_name = scalar_name.split('\x00')[0] - data_per_point_slice[scalar_name] = slice(cpt, cpt+nb_scalars) + data_per_point_slice[scalar_name] = slice(cpt, + cpt + nb_scalars) cpt += nb_scalars - if cpt < trk_reader.header[Field.NB_SCALARS_PER_POINT]: - data_per_point_slice['scalars'] = slice(cpt, trk_reader.header[Field.NB_SCALARS_PER_POINT]) + if cpt < hdr[Field.NB_SCALARS_PER_POINT]: + slice_obj = slice(cpt, hdr[Field.NB_SCALARS_PER_POINT]) + data_per_point_slice['scalars'] = slice_obj data_per_streamline_slice = {} - if trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: + if hdr[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: cpt = 0 - for property_name in trk_reader.header['property_name']: + for property_name in hdr['property_name']: property_name = asstr(property_name) if len(property_name) == 0: continue @@ -681,42 +708,50 @@ def load(cls, fileobj, lazy_load=False): # property name. nb_properties = 1 if property_name[-2] == '\x00' and property_name[-1] != '\x00': - nb_properties = int(np.fromstring(property_name[-1], np.int8)) + nb_properties = int(np.fromstring(property_name[-1], + np.int8)) property_name = property_name.split('\x00')[0] - data_per_streamline_slice[property_name] = slice(cpt, cpt+nb_properties) + slice_obj = slice(cpt, cpt + nb_properties) + data_per_streamline_slice[property_name] = slice_obj cpt += nb_properties - if cpt < trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]: - data_per_streamline_slice['properties'] = slice(cpt, trk_reader.header[Field.NB_PROPERTIES_PER_STREAMLINE]) + if cpt < hdr[Field.NB_PROPERTIES_PER_STREAMLINE]: + slice_obj = slice(cpt, hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) + data_per_streamline_slice['properties'] = slice_obj if lazy_load: def _read(): for pts, scals, props in trk_reader: - data_for_points = dict((k, scals[:, v]) for k, v in data_per_point_slice.items()) - data_for_streamline = dict((k, props[v]) for k, v in data_per_streamline_slice.items()) - yield TractogramItem(pts, data_for_streamline, data_for_points) + items = data_per_point_slice.items() + data_for_points = dict((k, scals[:, v]) for k, v in items) + items = data_per_streamline_slice.items() + data_for_streamline = dict((k, props[v]) for k, v in items) + yield TractogramItem(pts, + data_for_streamline, + data_for_points) tractogram = LazyTractogram.create_from(_read) else: - streamlines, scalars, properties = cls._create_arraysequence_from_generator(trk_reader) + arr_seqs = cls._create_arraysequence_from_generator(trk_reader) + streamlines, scalars, properties = arr_seqs tractogram = Tractogram(streamlines) - for scalar_name, slice_ in data_per_point_slice.items(): + for name, slice_ in data_per_point_slice.items(): seq = ArraySequence() seq._data = scalars._data[:, slice_] seq._offsets = scalars._offsets seq._lengths = scalars._lengths - tractogram.data_per_point[scalar_name] = seq + tractogram.data_per_point[name] = seq - for property_name, slice_ in 
data_per_streamline_slice.items(): - tractogram.data_per_streamline[property_name] = properties[:, slice_] + for name, slice_ in data_per_streamline_slice.items(): + tractogram.data_per_streamline[name] = properties[:, slice_] # Bring tractogram to RAS+ and mm space tractogram.apply_affine(affine.astype(np.float32)) - return cls(tractogram, header=trk_reader.header) + return cls(tractogram, header=hdr) def save(self, fileobj): """ Saves tractogram to a file-like object using TRK format. @@ -755,12 +790,16 @@ def __str__(self): info += "\nvoxel_sizes: {0}".format(hdr[Field.VOXEL_SIZES]) info += "\norgin: {0}".format(hdr[Field.ORIGIN]) info += "\nnb_scalars: {0}".format(hdr[Field.NB_SCALARS_PER_POINT]) - info += "\nscalar_name:\n {0}".format("\n".join(map(asstr, hdr['scalar_name']))) - info += "\nnb_properties: {0}".format(hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) - info += "\nproperty_name:\n {0}".format("\n".join(map(asstr, hdr['property_name']))) + info += "\nscalar_name:\n {0}".format( + "\n".join(map(asstr, hdr['scalar_name']))) + info += "\nnb_properties: {0}".format( + hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) + info += "\nproperty_name:\n {0}".format( + "\n".join(map(asstr, hdr['property_name']))) info += "\nvox_to_world: {0}".format(hdr[Field.VOXEL_TO_RASMM]) info += "\nvoxel_order: {0}".format(hdr[Field.VOXEL_ORDER]) - info += "\nimage_orientation_patient: {0}".format(hdr['image_orientation_patient']) + info += "\nimage_orientation_patient: {0}".format( + hdr['image_orientation_patient']) info += "\npad1: {0}".format(hdr['pad1']) info += "\npad2: {0}".format(hdr['pad2']) info += "\ninvert_x: {0}".format(hdr['invert_x']) From 7ba65c28e1e0432b28b75dbe91b7004042cfb9fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 23 Feb 2016 19:35:44 -0500 Subject: [PATCH 086/135] Addressed @matthew-brett's about coroutine and ArraySequence.extend method. --- nibabel/streamlines/__init__.py | 2 +- nibabel/streamlines/array_sequence.py | 77 ++++++-------- .../streamlines/tests/test_array_sequence.py | 18 +++- nibabel/streamlines/tests/test_streamlines.py | 48 +++++---- nibabel/streamlines/tests/test_tractogram.py | 32 ++++-- nibabel/streamlines/tractogram.py | 100 +++++++++++------- 6 files changed, 169 insertions(+), 108 deletions(-) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index 8119865c4a..51696979e7 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -127,6 +127,6 @@ def save(tractogram, filename, **kwargs): if len(kwargs) > 0: msg = ("A 'TractogramFile' object was provided, no need for" " keyword arguments.") - raise ValueError(msg) + raise ValueError(msg) tractogram_file.save(filename) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index f51d936fdf..f7aa2ff88a 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -1,3 +1,4 @@ +import numbers import numpy as np @@ -37,7 +38,7 @@ def __init__(self, iterable=None): """ # Create new empty `ArraySequence` object. self._is_view = False - self._data = np.array(0) + self._data = np.array([]) self._offsets = np.array([], dtype=np.intp) self._lengths = np.array([], dtype=np.intp) @@ -79,8 +80,7 @@ def __init__(self, iterable=None): self._lengths = np.asarray(lengths) # Clear unused memory. 
- if self._data.ndim != 0: - self._data.resize((offset,) + self.common_shape) + self._data.resize((offset,) + self.common_shape) @property def is_array_sequence(self): @@ -89,13 +89,10 @@ def is_array_sequence(self): @property def common_shape(self): """ Matching shape of the elements in this array sequence. """ - if self._data.ndim == 0: - return () - return self._data.shape[1:] def append(self, element): - """ Appends :obj:`element` to this array sequence. + """ Appends `element` to this array sequence. Parameters ---------- @@ -108,28 +105,28 @@ def append(self, element): If you need to add multiple elements you should consider `ArraySequence.extend`. """ - if self._data.ndim == 0: - self._data = np.asarray(element).copy() - self._offsets = np.array([0]) - self._lengths = np.array([len(element)]) - return + element = np.asarray(element) - if element.shape[1:] != self.common_shape: + if self.common_shape != () and element.shape[1:] != self.common_shape: msg = "All dimensions, except the first one, must match exactly" raise ValueError(msg) - self._offsets = np.r_[self._offsets, len(self._data)] - self._lengths = np.r_[self._lengths, len(element)] - self._data = np.append(self._data, element, axis=0) + next_offset = self._data.shape[0] + size = (self._data.shape[0] + element.shape[0],) + element.shape[1:] + self._data.resize(size) + self._data[next_offset:] = element + self._offsets = np.r_[self._offsets, next_offset] + self._lengths = np.r_[self._lengths, element.shape[0]] def extend(self, elements): """ Appends all `elements` to this array sequence. Parameters ---------- - elements : list of ndarrays or :class:`ArraySequence` object - If list of ndarrays, each ndarray will be concatenated along the - first dimension then appended to the data of this ArraySequence. + elements : iterable of ndarrays or :class:`ArraySequence` object + If iterable of ndarrays, each ndarray will be concatenated along + the first dimension then appended to the data of this + ArraySequence. If :class:`ArraySequence` object, its data are simply appended to the data of this ArraySequence. @@ -138,35 +135,31 @@ def extend(self, elements): The shape of the elements to be added must match the one of the data of this :class:`ArraySequence` except for the first dimension. 
""" + if not is_array_sequence(elements): + self.extend(ArraySequence(elements)) + return + if len(elements) == 0: return - if self._data.ndim == 0: - elem = np.asarray(elements[0]) - self._data = np.zeros((0, elem.shape[1]), dtype=elem.dtype) + if (self.common_shape != () and + elements.common_shape != self.common_shape): + msg = "All dimensions, except the first one, must match exactly" + raise ValueError(msg) next_offset = self._data.shape[0] + self._data.resize((self._data.shape[0] + sum(elements._lengths), + elements._data.shape[1])) - if is_array_sequence(elements): - self._data.resize((self._data.shape[0] + sum(elements._lengths), - self._data.shape[1])) - - offsets = [] - for offset, length in zip(elements._offsets, elements._lengths): - offsets.append(next_offset) - chunk = elements._data[offset:offset + length] - self._data[next_offset:next_offset + length] = chunk - next_offset += length - - self._lengths = np.r_[self._lengths, elements._lengths] - self._offsets = np.r_[self._offsets, offsets] + offsets = [] + for offset, length in zip(elements._offsets, elements._lengths): + offsets.append(next_offset) + chunk = elements._data[offset:offset + length] + self._data[next_offset:next_offset + length] = chunk + next_offset += length - else: - self._data = np.concatenate([self._data] + list(elements), axis=0) - lengths = list(map(len, elements)) - self._lengths = np.r_[self._lengths, lengths] - self._offsets = np.r_[self._offsets, - np.cumsum([next_offset] + lengths)[:-1]] + self._lengths = np.r_[self._lengths, elements._lengths] + self._offsets = np.r_[self._offsets, offsets] def copy(self): """ Creates a copy of this :class:`ArraySequence` object. """ @@ -210,7 +203,7 @@ def __getitem__(self, idx): Otherwise, returns a :class:`ArraySequence` object which is view of the selected sequences. """ - if isinstance(idx, (int, np.integer)): + if isinstance(idx, (numbers.Integral, np.integer)): start = self._offsets[idx] return self._data[start:start + self._lengths[idx]] diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index 4cf41e7e05..8639ff58f8 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -1,4 +1,5 @@ import os +import sys import unittest import tempfile import numpy as np @@ -6,7 +7,6 @@ from nose.tools import assert_equal, assert_raises, assert_true from nibabel.testing import assert_arrays_equal from numpy.testing import assert_array_equal -from nibabel.externals.six.moves import zip, zip_longest from ..array_sequence import ArraySequence, is_array_sequence @@ -32,7 +32,8 @@ def check_empty_arr_seq(seq): assert_equal(len(seq), 0) assert_equal(len(seq._offsets), 0) assert_equal(len(seq._lengths), 0) - assert_equal(seq._data.ndim, 0) + # assert_equal(seq._data.ndim, 0) + assert_equal(seq._data.ndim, 1) assert_true(seq.common_shape == ()) @@ -138,6 +139,11 @@ def test_arraysequence_append(self): seq.append(element) check_arr_seq(seq, SEQ_DATA['data'] + [element]) + # Append a list of list. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + seq.append(element.tolist()) + check_arr_seq(seq, SEQ_DATA['data'] + [element]) + # Append to an empty ArraySequence. seq = ArraySequence() seq.append(element) @@ -164,6 +170,11 @@ def test_arraysequence_extend(self): seq.extend(new_data) check_arr_seq(seq, SEQ_DATA['data'] + new_data) + # Extend with a generator. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. 
+ seq.extend((d for d in new_data)) + check_arr_seq(seq, SEQ_DATA['data'] + new_data) + # Extend with another `ArraySequence` object. seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. seq.extend(ArraySequence(new_data)) @@ -195,6 +206,9 @@ def test_arraysequence_getitem(self): for i, e in enumerate(SEQ_DATA['seq']): assert_array_equal(SEQ_DATA['seq'][i], e) + if sys.version_info < (3,): + assert_array_equal(SEQ_DATA['seq'][long(i)], e) + # Get all items using indexing (creates a view). indices = list(range(len(SEQ_DATA['seq']))) seq_view = SEQ_DATA['seq'][indices] diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 32f0e1a736..cecc5c295e 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -226,23 +226,28 @@ def test_save_tractogram_file(self): assert_true(issubclass(w[0].category, ExtensionWarning)) assert_true("extension" in str(w[0].message)) + with InTemporaryDirectory(): + nib.streamlines.save(trk_file, "dummy.trk") + tfile = nib.streamlines.load("dummy.trk", lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) + def test_save_empty_file(self): tractogram = Tractogram() for ext, cls in nib.streamlines.FORMATS.items(): with InTemporaryDirectory(): - with open('streamlines' + ext, 'w+b') as f: - nib.streamlines.save(tractogram, f.name) - tfile = nib.streamlines.load(f, lazy_load=False) - assert_tractogram_equal(tfile.tractogram, tractogram) + filename = 'streamlines' + ext + nib.streamlines.save(tractogram, filename) + tfile = nib.streamlines.load(filename, lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) def test_save_simple_file(self): tractogram = Tractogram(DATA['streamlines']) for ext, cls in nib.streamlines.FORMATS.items(): with InTemporaryDirectory(): - with open('streamlines' + ext, 'w+b') as f: - nib.streamlines.save(tractogram, f.name) - tfile = nib.streamlines.load(f, lazy_load=False) - assert_tractogram_equal(tfile.tractogram, tractogram) + filename = 'streamlines' + ext + nib.streamlines.save(tractogram, filename) + tfile = nib.streamlines.load(filename, lazy_load=False) + assert_tractogram_equal(tfile.tractogram, tractogram) def test_save_complex_file(self): complex_tractogram = Tractogram(DATA['streamlines'], @@ -251,18 +256,19 @@ def test_save_complex_file(self): for ext, cls in nib.streamlines.FORMATS.items(): with InTemporaryDirectory(): - with open('streamlines' + ext, 'w+b') as f: - with clear_and_catch_warnings(record=True, - modules=[trk]) as w: - nib.streamlines.save(complex_tractogram, f.name) - - # If streamlines format does not support saving data - # per point or data per streamline, a warning message - # should be issued. - if not (cls.support_data_per_point() and - cls.support_data_per_streamline()): - assert_equal(len(w), 1) - assert_true(issubclass(w[0].category, Warning)) + filename = 'streamlines' + ext + + with clear_and_catch_warnings(record=True, + modules=[trk]) as w: + nib.streamlines.save(complex_tractogram, filename) + + # If streamlines format does not support saving data + # per point or data per streamline, a warning message + # should be issued. 
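Note: the save tests above reduce to this round trip through the module-level
API. A sketch, assuming a writable working directory (the filename is made
up; `affine_to_rasmm` is passed explicitly since some versions require it to
write a TRK file):

    import numpy as np
    import nibabel as nib
    from nibabel.streamlines import Tractogram

    streamlines = [np.arange(n * 3, dtype="f4").reshape((n, 3))
                   for n in (1, 2, 5)]
    tractogram = Tractogram(streamlines, affine_to_rasmm=np.eye(4))

    nib.streamlines.save(tractogram, 'example.trk')   # Format from extension.
    tfile = nib.streamlines.load('example.trk', lazy_load=False)
    assert len(tfile.tractogram) == len(streamlines)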
+                    if not (cls.support_data_per_point() and
+                            cls.support_data_per_streamline()):
+                        assert_equal(len(w), 1)
+                        assert_true(issubclass(w[0].category, Warning))

                 tractogram = Tractogram(DATA['streamlines'])

@@ -272,7 +278,7 @@ def test_save_complex_file(self):
                 if cls.support_data_per_streamline():
                     tractogram.data_per_streamline = DATA['data_per_streamline']

-                tfile = nib.streamlines.load(f, lazy_load=False)
+                tfile = nib.streamlines.load(filename, lazy_load=False)
                 assert_tractogram_equal(tfile.tractogram, tractogram)

     def test_load_unknown_format(self):
diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py
index bf4b936fad..a5bd1c1a03 100644
--- a/nibabel/streamlines/tests/test_tractogram.py
+++ b/nibabel/streamlines/tests/test_tractogram.py
@@ -1,3 +1,4 @@
+import sys
 import unittest
 import numpy as np
 import warnings
@@ -241,6 +242,22 @@ def test_tractogram_creation(self):
                                 DATA['data_per_streamline'],
                                 DATA['data_per_point'])

+        # Create a tractogram from another tractogram's attributes.
+        tractogram2 = Tractogram(tractogram.streamlines,
+                                 tractogram.data_per_streamline,
+                                 tractogram.data_per_point)
+
+        assert_tractogram_equal(tractogram2, tractogram)
+
+        # Create a tractogram from a LazyTractogram object.
+        tractogram = LazyTractogram(DATA['streamlines_func'],
+                                    DATA['data_per_streamline_func'],
+                                    DATA['data_per_point_func'])
+
+        tractogram2 = Tractogram(tractogram.streamlines,
+                                 tractogram.data_per_streamline,
+                                 tractogram.data_per_point)
+
         # Inconsistent number of scalars between streamlines
         wrong_data = [[(1, 0, 0)]*1,
                       [(0, 1, 0), (0, 1)],
@@ -264,6 +281,9 @@ def test_tractogram_getitem(self):
         for i, t in enumerate(DATA['tractogram']):
             assert_tractogram_item_equal(DATA['tractogram'][i], t)

+            if sys.version_info < (3,):
+                assert_tractogram_item_equal(DATA['tractogram'][long(i)], t)
+
         # Get one TractogramItem out of two.
         tractogram_view = DATA['simple_tractogram'][::2]
         check_tractogram(tractogram_view, DATA['streamlines'][::2])
@@ -411,7 +431,7 @@ def test_lazy_tractogram_creation(self):
                                'mean_colors': (x for x in DATA['mean_colors'])}

         # Creating LazyTractogram with generators is not allowed as
-        # generators get exhausted and are not reusable unlike coroutines.
+        # generators get exhausted and are not reusable unlike generator functions.
         assert_raises(TypeError, LazyTractogram, streamlines)
         assert_raises(TypeError, LazyTractogram,
                       data_per_streamline=data_per_streamline)
@@ -430,7 +450,7 @@ def test_lazy_tractogram_creation(self):
         assert_true(check_iteration(tractogram))
         assert_equal(len(tractogram), len(DATA['streamlines']))

-        # Coroutines get re-called and creates new iterators.
+        # Generator functions get re-called and create new iterators.
         for i in range(2):
             assert_tractogram_equal(tractogram, DATA['tractogram'])

@@ -441,7 +461,7 @@ def test_lazy_tractogram_create_from(self):
         tractogram = LazyTractogram.create_from(_empty_data_gen)
         check_tractogram(tractogram)

-        # Create `LazyTractogram` from a coroutine yielding TractogramItem
+        # Create `LazyTractogram` from a generator function yielding TractogramItem.
         data = [DATA['streamlines'], DATA['fa'], DATA['colors'],
                 DATA['mean_curvature'], DATA['mean_torsion'],
                 DATA['mean_colors']]
@@ -530,13 +550,13 @@ def test_lazy_tractogram_copy(self):

         # Check we copied the data and not simply created new references.
         assert_true(tractogram is not DATA['lazy_tractogram'])

-        # When copying LazyTractogram, coroutines generating streamlines should
-        # be the same.
+        # When copying LazyTractogram, the generator function yielding streamlines
+        # should stay the same.
         assert_true(tractogram._streamlines
                     is DATA['lazy_tractogram']._streamlines)

         # Copying LazyTractogram, creates new internal LazyDict objects,
-        # but coroutines contained in it should be the same.
+        # but generator functions contained in it should stay the same.
         assert_true(tractogram._data_per_streamline
                     is not DATA['lazy_tractogram']._data_per_streamline)
         assert_true(tractogram._data_per_point
diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
index e620682296..491aa29687 100644
--- a/nibabel/streamlines/tractogram.py
+++ b/nibabel/streamlines/tractogram.py
@@ -1,4 +1,5 @@
 import copy
+import numbers
 import numpy as np
 import collections
 from warnings import warn
@@ -8,6 +9,16 @@
 from .array_sequence import ArraySequence


+def is_data_dict(obj):
+    """ Tells if obj is a :class:`DataDict`. """
+    return hasattr(obj, 'store')
+
+
+def is_lazy_dict(obj):
+    """ Tells if obj is a :class:`LazyDict`. """
+    return is_data_dict(obj) and callable(list(obj.store.values())[0])
+
+
 class DataDict(collections.MutableMapping):
     """ Dictionary that makes sure data are 2D array.

@@ -24,7 +35,7 @@ def __init__(self, tractogram, *args, **kwargs):
         # Use update to set the keys.
         if len(args) == 1:
             if isinstance(args[0], DataDict):
-                self.update(dict(args[0].store.items()))
+                self.update(**args[0])
             elif args[0] is None:
                 return
             else:
@@ -75,7 +86,7 @@ class DataPerStreamlineDict(DataDict):
     dictionary.
     """
     def __setitem__(self, key, value):
-        value = np.asarray(value)
+        value = np.asarray(list(value))

         if value.ndim == 1 and value.dtype != object:
             # Reshape without copy
@@ -121,19 +132,29 @@ def __setitem__(self, key, value):


 class LazyDict(DataDict):
-    """ Dictionary of coroutines with lazy evaluation.
+    """ Dictionary of generator functions.

     This container behaves like a dictionary but it makes sure its elements
-    are callable objects and assumed to be coroutines yielding values. When
-    getting the element associated to a given key, the element (i.e. a
-    coroutine) is first called before being returned.
+    are callable objects and assumed to be generator functions yielding values.
+    When getting the element associated to a given key, the element (i.e. a
+    generator function) is first called before being returned.
     """
+    def __init__(self, tractogram, *args, **kwargs):
+        if len(args) == 1 and isinstance(args[0], LazyDict):
+            # Copy the generator functions.
+            self.tractogram = tractogram
+            self.store = dict()
+            self.update(**args[0].store)
+            return
+
+        super(LazyDict, self).__init__(tractogram, *args, **kwargs)
+
     def __getitem__(self, key):
         return self.store[key]()

     def __setitem__(self, key, value):
         if value is not None and not callable(value):
-            raise TypeError("`value` must be a coroutine or None.")
+            raise TypeError("`value` must be a generator function or None.")

         self.store[key] = value

@@ -222,6 +243,9 @@ def data_per_streamline(self):

     @data_per_streamline.setter
     def data_per_streamline(self, value):
+        # if is_lazy_dict(value):
+        #     self._data_per_streamline = DataPerStreamlineDict(self, **value.items())
+        # else:
         self._data_per_streamline = DataPerStreamlineDict(self, value)

     @property
@@ -230,6 +254,9 @@ def data_per_point(self):

     @data_per_point.setter
     def data_per_point(self, value):
+        # if is_lazy_dict(value):
+        #     self._data_per_point = DataPerPointDict(self, **value.items())
+        # else:
         self._data_per_point = DataPerPointDict(self, value)

     def get_affine_to_rasmm(self):
@@ -251,7 +278,7 @@ def __getitem__(self, idx):
         for key in self.data_per_point:
             data_per_point[key] = self.data_per_point[key][idx]

-        if isinstance(idx, (int, np.integer)):
+        if isinstance(idx, (numbers.Integral, np.integer)):
             return TractogramItem(pts, data_per_streamline, data_per_point)

         return Tractogram(pts, data_per_streamline, data_per_point)
@@ -270,7 +297,7 @@ def apply_affine(self, affine, lazy=False):

         Parameters
         ----------
-        affine : ndarray shape (4, 4)
+        affine : ndarray of shape (4, 4)
             Transformation that will be applied to every streamline.
         lazy_load : {False, True}, optional
             If True, streamlines are *not* transformed in-place and a
@@ -326,19 +353,25 @@ def __init__(self, streamlines=None,
         """
         Parameters
         ----------
-        streamlines : coroutine yielding ndarrays of shape (Nt,3) (optional)
-            Function yielding streamlines. One streamline is an ndarray of
-            shape (Nt,3) where Nt is the number of points of streamline t.
-        data_per_streamline : dict of coroutines yielding ndarrays of shape (P,) (optional)
-            Function yielding properties for a particular streamline t. The
-            properties are represented as an ndarray of shape (P,) where P is
-            the number of properties associated to each streamline.
-        data_per_point : dict of coroutines yielding ndarrays of shape (Nt,M) (optional)
-            Function yielding scalars for a particular streamline t. The
-            scalars are represented as an ndarray of shape (Nt,M) where Nt
-            is the number of points of that streamline t and M is the number
-            of scalars associated to each point (excluding the three
-            coordinates).
+        streamlines : generator function yielding, optional
+            Generator function yielding streamlines. One streamline is an
+            ndarray of shape ($N_t$, 3) where $N_t$ is the number of points of
+            streamline $t$.
+        data_per_streamline : dict of generator functions, optional
+            Dictionary where the items are (str, generator function).
+            Each key represents a piece of information $i$ to be kept alongside every
+            streamline, and its associated value is a generator function
+            yielding that information via ndarrays of shape ($P_i$,) where
+            $P_i$ is the number of scalar values to store for that particular
+            information $i$.
+        data_per_point : dict of generator functions, optional
+            Dictionary where the items are (str, generator function).
+            Each key represents a piece of information $i$ to be kept alongside every
+            point of every streamline, and its associated value is a generator
+            function yielding that information via ndarrays of shape
+            ($N_t$, $M_i$) where $N_t$ is the number of points for a particular
+            streamline $t$ and $M_i$ is the number of scalar values to store for
+            that particular information $i$.
         """
         super(LazyTractogram, self).__init__(streamlines,
                                              data_per_streamline,
@@ -379,15 +412,16 @@ def from_tractogram(cls, tractogram):

     @classmethod
     def create_from(cls, data_func):
-        """ Creates a :class:`LazyTractogram` from a coroutine yielding
-        :class:`TractogramItem` objects.
+        """ Creates an instance from a generator function.
+
+        The generator function must yield :class:`TractogramItem` objects.

         Parameters
         ----------
-        data_func : coroutine yielding :class:`TractogramItem` objects
-            A function that whenever it is called starts yielding
-            :class:`TractogramItem` objects that should be part of this
-            LazyTractogram.
+        data_func : generator function yielding :class:`TractogramItem` objects
+            Generator function that whenever it is called starts yielding
+            :class:`TractogramItem` objects that will be used to instantiate a
+            :class:`LazyTractogram`.

         Returns
         -------
@@ -395,7 +429,7 @@ def create_from(cls, data_func):
             New lazy tractogram.
         """
         if not callable(data_func):
-            raise TypeError("`data_func` must be a coroutine.")
+            raise TypeError("`data_func` must be a generator function.")

         lazy_tractogram = cls()
         lazy_tractogram._data = data_func
@@ -445,7 +479,7 @@ def _apply_affine():

     @streamlines.setter
     def streamlines(self, value):
         if value is not None and not callable(value):
-            raise TypeError("`streamlines` must be a coroutine.")
+            raise TypeError("`streamlines` must be a generator function.")

         self._streamlines = value

@@ -455,9 +489,6 @@ def data_per_streamline(self):

     @data_per_streamline.setter
     def data_per_streamline(self, value):
-        if value is None:
-            value = {}
-
         self._data_per_streamline = LazyDict(self, value)

     @property
@@ -466,9 +497,6 @@ def data_per_point(self):

     @data_per_point.setter
     def data_per_point(self, value):
-        if value is None:
-            value = {}
-
         self._data_per_point = LazyDict(self, value)

     @property

From 909e0b61d41f377c19a1b2645c660fce0b60b235 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Wed, 24 Feb 2016 00:57:55 -0500
Subject: [PATCH 087/135] Added Tractogram.to_world and LazyTractogram.to_world
 methods

---
 nibabel/streamlines/tests/test_tractogram.py | 108 +++++++++++++++----
 nibabel/streamlines/tractogram.py            |  69 +++++++++---
 2 files changed, 142 insertions(+), 35 deletions(-)

diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py
index a5bd1c1a03..77d9cd67e5 100644
--- a/nibabel/streamlines/tests/test_tractogram.py
+++ b/nibabel/streamlines/tests/test_tractogram.py
@@ -389,27 +389,22 @@ def test_tractogram_apply_affine(self):

         # Apply the affine to the streamline in a lazy manner.
         transformed_tractogram = tractogram.apply_affine(affine, lazy=True)
         assert_true(type(transformed_tractogram) is LazyTractogram)
-        assert_true(check_iteration(transformed_tractogram))
-        assert_equal(len(transformed_tractogram), len(DATA['streamlines']))
-        for s1, s2 in zip(transformed_tractogram.streamlines,
-                          DATA['streamlines']):
-            assert_array_almost_equal(s1, s2*scaling)
-
-        for s1, s2 in zip(transformed_tractogram.streamlines,
-                          tractogram.streamlines):
-            assert_array_almost_equal(s1, s2*scaling)
-
+        check_tractogram(transformed_tractogram,
+                         streamlines=[s*scaling for s in DATA['streamlines']],
+                         data_per_streamline=DATA['data_per_streamline'],
+                         data_per_point=DATA['data_per_point'])
         assert_array_equal(transformed_tractogram.get_affine_to_rasmm(),
                            np.dot(np.eye(4), np.linalg.inv(affine)))

+        # Make sure streamlines of the original tractogram have not been modified.
+        assert_arrays_equal(tractogram.streamlines, DATA['streamlines'])

         # Apply the affine to the streamlines in-place.
         transformed_tractogram = tractogram.apply_affine(affine)
         assert_true(transformed_tractogram is tractogram)
-        assert_true(check_iteration(transformed_tractogram))
-        assert_equal(len(transformed_tractogram), len(DATA['streamlines']))
-        for s1, s2 in zip(transformed_tractogram.streamlines,
-                          DATA['streamlines']):
-            assert_array_almost_equal(s1, s2*scaling)
+        check_tractogram(tractogram,
+                         streamlines=[s*scaling for s in DATA['streamlines']],
+                         data_per_streamline=DATA['data_per_streamline'],
+                         data_per_point=DATA['data_per_point'])

         # Apply affine again and check the affine_to_rasmm.
         transformed_tractogram = tractogram.apply_affine(affine)
@@ -417,6 +412,50 @@ def test_tractogram_apply_affine(self):
                            np.dot(np.eye(4), np.dot(np.linalg.inv(affine),
                                                     np.linalg.inv(affine))))

+        # Check that applying an affine and its inverse gives us back the
+        # original streamlines.
+        tractogram = DATA['tractogram'].copy()
+        affine = np.random.RandomState(1234).randn(4, 4)
+        affine[-1] = [0, 0, 0, 1]  # Remove perspective projection.
+
+        tractogram.apply_affine(affine)
+        tractogram.apply_affine(np.linalg.inv(affine))
+        assert_array_almost_equal(tractogram.get_affine_to_rasmm(), np.eye(4))
+        for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']):
+            assert_array_almost_equal(s1, s2)
+
+    def test_tractogram_to_world(self):
+        tractogram = DATA['tractogram'].copy()
+        affine = np.random.RandomState(1234).randn(4, 4)
+        affine[-1] = [0, 0, 0, 1]  # Remove perspective projection.
+
+        # Apply the affine to the streamlines, then bring them back
+        # to world space in a lazy manner.
+        transformed_tractogram = tractogram.apply_affine(affine)
+        assert_array_equal(transformed_tractogram.get_affine_to_rasmm(),
+                           np.linalg.inv(affine))
+
+        tractogram_world = transformed_tractogram.to_world(lazy=True)
+        assert_true(type(tractogram_world) is LazyTractogram)
+        assert_array_almost_equal(tractogram_world.get_affine_to_rasmm(),
+                                  np.eye(4))
+        for s1, s2 in zip(tractogram_world.streamlines, DATA['streamlines']):
+            assert_array_almost_equal(s1, s2)
+
+        # Bring the streamlines back to world space in an in-place manner.
+        tractogram_world = transformed_tractogram.to_world()
+        assert_true(tractogram_world is tractogram)
+        assert_array_almost_equal(tractogram.get_affine_to_rasmm(), np.eye(4))
+        for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']):
+            assert_array_almost_equal(s1, s2)
+
+        # Calling to_world twice should do nothing.
+        tractogram_world2 = transformed_tractogram.to_world()
+        assert_true(tractogram_world2 is tractogram)
+        assert_array_almost_equal(tractogram.get_affine_to_rasmm(), np.eye(4))
+        for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']):
+            assert_array_almost_equal(s1, s2)
+

 class TestLazyTractogram(unittest.TestCase):

@@ -531,18 +570,47 @@ def test_lazy_tractogram_apply_affine(self):

         tractogram = DATA['lazy_tractogram'].copy()

-        tractogram.apply_affine(affine)
-        assert_true(check_iteration(tractogram))
-        assert_equal(len(tractogram), len(DATA['streamlines']))
-        for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']):
-            assert_array_almost_equal(s1, s2*scaling)
+        transformed_tractogram = tractogram.apply_affine(affine)
+        assert_true(transformed_tractogram is tractogram)
+        assert_array_equal(tractogram._affine_to_apply, affine)
+        assert_array_equal(tractogram.get_affine_to_rasmm(),
+                           np.dot(np.eye(4), np.linalg.inv(affine)))
+        check_tractogram(tractogram,
+                         streamlines=[s*scaling for s in DATA['streamlines']],
+                         data_per_streamline=DATA['data_per_streamline'],
+                         data_per_point=DATA['data_per_point'])

         # Apply affine again and check the affine_to_rasmm.
         transformed_tractogram = tractogram.apply_affine(affine)
+        assert_array_equal(tractogram._affine_to_apply, np.dot(affine, affine))
         assert_array_equal(transformed_tractogram.get_affine_to_rasmm(),
                            np.dot(np.eye(4), np.dot(np.linalg.inv(affine),
                                                     np.linalg.inv(affine))))

+    def test_tractogram_to_world(self):
+        tractogram = DATA['lazy_tractogram'].copy()
+        affine = np.random.RandomState(1234).randn(4, 4)
+        affine[-1] = [0, 0, 0, 1]  # Remove perspective projection.
+
+        # Apply the affine to the streamlines, then bring them back
+        # to world space in a lazy manner.
+        tractogram.apply_affine(affine)
+        assert_array_equal(tractogram.get_affine_to_rasmm(),
+                           np.linalg.inv(affine))
+
+        tractogram_world = tractogram.to_world()
+        assert_true(tractogram_world is tractogram)
+        assert_array_almost_equal(tractogram.get_affine_to_rasmm(),
+                                  np.eye(4))
+        for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']):
+            assert_array_almost_equal(s1, s2)
+
+        # Calling to_world twice should do nothing.
+        tractogram.to_world()
+        assert_array_almost_equal(tractogram.get_affine_to_rasmm(), np.eye(4))
+        for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']):
+            assert_array_almost_equal(s1, s2)
+
     def test_lazy_tractogram_copy(self):
         # Create a copy of the lazy tractogram.
         tractogram = DATA['lazy_tractogram'].copy()
diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
index 491aa29687..411d3123fd 100644
--- a/nibabel/streamlines/tractogram.py
+++ b/nibabel/streamlines/tractogram.py
@@ -198,6 +198,12 @@ class Tractogram(object):
     Tractogram objects have three main properties: `streamlines`,
     `data_per_streamline` and `data_per_point`.

+    Streamlines of a tractogram can be in any coordinate system of your
+    choice as long as you provide the correct `affine_to_rasmm` matrix, at
+    construction time, that brings the streamlines back to *RAS+*, *mm* space,
+    where the coordinate (0,0,0) corresponds to the center of the voxel
+    (as opposed to a corner).
+ """ def __init__(self, streamlines=None, data_per_streamline=None, @@ -243,9 +249,6 @@ def data_per_streamline(self): @data_per_streamline.setter def data_per_streamline(self, value): - # if is_lazy_dict(value): - # self._data_per_streamline = DataPerStreamlineDict(self, **value.items()) - # else: self._data_per_streamline = DataPerStreamlineDict(self, value) @property @@ -254,9 +257,6 @@ def data_per_point(self): @data_per_point.setter def data_per_point(self, value): - # if is_lazy_dict(value): - # self._data_per_point = DataPerPointDict(self, **value.items()) - # else: self._data_per_point = DataPerPointDict(self, value) def get_affine_to_rasmm(self): @@ -333,6 +333,28 @@ def apply_affine(self, affine, lazy=False): return self + def to_world(self, lazy=False): + """ Brings the streamlines to world space (i.e. RAS+ and mm). + + If `lazy` is not specified, this is performed *in-place*. + + Parameters + ---------- + lazy_load : {False, True}, optional + If True, streamlines are *not* transformed in-place and a + :class:`LazyTractogram` object is returned. Otherwise, streamlines + are modified in-place. + + Returns + ------- + tractogram : :class:`Tractogram` or :class:`LazyTractogram` object + Tractogram where the streamlines have been sent to world space. + If the `lazy` option is true, it returns a :class:`LazyTractogram` + object, otherwise it returns a reference to this + :class:`Tractogram` object with updated streamlines. + """ + return self.apply_affine(self._affine_to_rasmm, lazy=lazy) + class LazyTractogram(Tractogram): """ Class containing information about streamlines. @@ -394,17 +416,21 @@ def from_tractogram(cls, tractogram): lazy_tractogram : :class:`LazyTractogram` object New lazy tractogram. """ - data_per_streamline = {} - for key, value in tractogram.data_per_streamline.items(): - data_per_streamline[key] = lambda: value + lazy_tractogram = cls(lambda: tractogram.streamlines.copy()) - data_per_point = {} - for key, value in tractogram.data_per_point.items(): - data_per_point[key] = lambda: value + # Set data_per_streamline using data_func + def _gen(key): + return lambda: iter(tractogram.data_per_streamline[key]) + + for k in tractogram.data_per_streamline: + lazy_tractogram._data_per_streamline[k] = _gen(k) - lazy_tractogram = cls(lambda: tractogram.streamlines.copy(), - data_per_streamline, - data_per_point) + # Set data_per_point using data_func + def _gen(key): + return lambda: iter(tractogram.data_per_point[key]) + + for k in tractogram.data_per_point: + lazy_tractogram._data_per_point[k] = _gen(k) lazy_tractogram._nb_streamlines = len(tractogram) lazy_tractogram._affine_to_rasmm = tractogram.get_affine_to_rasmm() @@ -584,3 +610,16 @@ def apply_affine(self, affine): self._affine_to_rasmm = np.dot(self._affine_to_rasmm, np.linalg.inv(affine)) return self + + def to_world(self): + """ Brings the streamlines to world space (i.e. RAS+ and mm). + + The transformation will be applied just before returning the + streamlines. + + Returns + ------- + lazy_tractogram : :class:`LazyTractogram` object + Reference to this instance of :class:`LazyTractogram`. 
+        """
+        return self.apply_affine(self._affine_to_rasmm)

From 111f53a6cbd0e31a2702cca828ac87c123173cb2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Wed, 24 Feb 2016 01:32:23 -0500
Subject: [PATCH 088/135] Clarified Tractogram and LazyTractogram docstring

---
 nibabel/streamlines/tractogram.py | 108 ++++++++++++++++++++++--------
 1 file changed, 80 insertions(+), 28 deletions(-)

diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
index 411d3123fd..70e6af7323 100644
--- a/nibabel/streamlines/tractogram.py
+++ b/nibabel/streamlines/tractogram.py
@@ -193,10 +193,7 @@ def __len__(self):


 class Tractogram(object):
-    """ Class containing information about streamlines.
-
-    Tractogram objects have three main properties: `streamlines`,
-    `data_per_streamline` and `data_per_point`.
+    """ Container for streamlines and their data information.

     Streamlines of a tractogram can be in any coordinate system of your
     choice as long as you provide the correct `affine_to_rasmm` matrix, at
@@ -204,6 +201,25 @@ class Tractogram(object):
     where the coordinate (0,0,0) corresponds to the center of the voxel
     (as opposed to a corner).

+    Attributes
+    ----------
+    streamlines : :class:`ArraySequence` object
+        Sequence of $T$ streamlines. Each streamline is an ndarray of
+        shape ($N_t$, 3) where $N_t$ is the number of points of
+        streamline $t$.
+    data_per_streamline : dict of 2D arrays
+        Dictionary where the items are (str, 2D array).
+        Each key represents a piece of information $i$ to be kept alongside every
+        streamline, and its associated value is a 2D array of shape
+        ($T$, $P_i$) where $T$ is the number of streamlines and $P_i$ is
+        the number of scalar values to store for that particular information $i$.
+    data_per_point : dict of :class:`ArraySequence` objects
+        Dictionary where the items are (str, :class:`ArraySequence`).
+        Each key represents a piece of information $i$ to be kept alongside every
+        point of every streamline, and its associated value is an iterable
+        of ndarrays of shape ($N_t$, $M_i$) where $N_t$ is the number of
+        points for a particular streamline $t$ and $M_i$ is the number of
+        scalar values to store for that particular information $i$.
     """
     def __init__(self, streamlines=None,
                  data_per_streamline=None,
@@ -212,19 +228,23 @@ def __init__(self, streamlines=None,
         """
         Parameters
         ----------
-        streamlines : list of ndarray of shape (Nt, 3) (optional)
-            Sequence of T streamlines. One streamline is an ndarray of
-            shape (Nt, 3) where Nt is the number of points of streamline t.
-        data_per_streamline : dict of list of ndarray of shape (P,) (optional)
-            Sequence of T ndarrays of shape (P,) where T is the number of
-            streamlines defined by `streamlines`, P is the number of
-            properties associated to each streamline.
-        data_per_point : dict of list of ndarray of shape (Nt, M) (optional)
-            Sequence of T ndarrays of shape (Nt, M) where T is the number
-            of streamlines defined by `streamlines`, Nt is the number of
-            points for a particular streamline t and M is the number of
-            scalars associated to each point (excluding the three
-            coordinates).
+        streamlines : iterable of ndarrays or :class:`ArraySequence`, optional
+            Sequence of $T$ streamlines. Each streamline is an ndarray of
+            shape ($N_t$, 3) where $N_t$ is the number of points of
+            streamline $t$.
+        data_per_streamline : dict of iterable of ndarrays, optional
+            Dictionary where the items are (str, iterable).
+            Each key represents a piece of information $i$ to be kept alongside every
+            streamline, and its associated value is an iterable of ndarrays of
+            shape ($P_i$,) where $P_i$ is the number of scalar values to store
+            for that particular information $i$.
+        data_per_point : dict of iterable of ndarrays, optional
+            Dictionary where the items are (str, iterable).
+            Each key represents a piece of information $i$ to be kept alongside every
+            point of every streamline, and its associated value is an iterable
+            of ndarrays of shape ($N_t$, $M_i$) where $N_t$ is the number of
+            points for a particular streamline $t$ and $M_i$ is the number of
+            scalar values to store for that particular information $i$.
         affine_to_rasmm : ndarray of shape (4, 4)
             Transformation matrix that brings the streamlines contained in
             this tractogram to *RAS+* and *mm* space where coordinate (0,0,0)
@@ -357,26 +377,53 @@ def to_world(self, lazy=False):


 class LazyTractogram(Tractogram):
-    """ Class containing information about streamlines.
+    """ Lazy container for streamlines and their data information.
+
+    This container behaves lazily as it uses generator functions to manage
+    streamlines and their data information. This container is thus memory
+    friendly since it does not require loading all the data in memory.
+
+    Streamlines of a lazy tractogram can be in any coordinate system of your
+    choice as long as you provide the correct `affine_to_rasmm` matrix, at
+    construction time, that brings the streamlines back to *RAS+*, *mm* space,
+    where the coordinate (0,0,0) corresponds to the center of the voxel
+    (as opposed to a corner).

-    Tractogram objects have four main properties: `header`, `streamlines`,
-    `scalars` and `properties`. Tractogram objects are iterable and
-    produce tuple of `streamlines`, `scalars` and `properties` for each
-    streamline.
+    Attributes
+    ----------
+    streamlines : generator function
+        Generator function yielding streamlines. Each streamline is an
+        ndarray of shape ($N_t$, 3) where $N_t$ is the number of points of
+        streamline $t$.
+    data_per_streamline : :class:`LazyDict` object
+        Dictionary where the items are (str, instantiated generator).
+        Each key represents a piece of information $i$ to be kept alongside every
+        streamline, and its associated value is a generator function
+        yielding that information via ndarrays of shape ($P_i$,) where
+        $P_i$ is the number of scalar values to store for that particular
+        information $i$.
+    data_per_point : :class:`LazyDict` object
+        Dictionary where the items are (str, instantiated generator).
+        Each key represents a piece of information $i$ to be kept alongside every
+        point of every streamline, and its associated value is a generator
+        function yielding that information via ndarrays of shape
+        ($N_t$, $M_i$) where $N_t$ is the number of points for a particular
+        streamline $t$ and $M_i$ is the number of scalar values to store for
+        that particular information $i$.

     Notes
     -----
-    If provided, `scalars` and `properties` must yield the same number of
-    values as `streamlines`.
+    LazyTractogram objects do not support indexing currently.
     """
     def __init__(self, streamlines=None,
                  data_per_streamline=None,
-                 data_per_point=None):
+                 data_per_point=None,
+                 affine_to_rasmm=np.eye(4)):
         """
         Parameters
         ----------
-        streamlines : generator function yielding, optional
-            Generator function yielding streamlines. One streamline is an
+        streamlines : generator function, optional
+            Generator function yielding streamlines. Each streamline is an
             ndarray of shape ($N_t$, 3) where $N_t$ is the number of points of
             streamline $t$.
         data_per_streamline : dict of generator functions, optional
             Dictionary where the items are (str, generator function).
@@ -394,10 +441,15 @@ def __init__(self, streamlines=None,
             ($N_t$, $M_i$) where $N_t$ is the number of points for a particular
             streamline $t$ and $M_i$ is the number of scalar values to store for
             that particular information $i$.
+        affine_to_rasmm : ndarray of shape (4, 4)
+            Transformation matrix that brings the streamlines contained in
+            this tractogram to *RAS+* and *mm* space where coordinate (0,0,0)
+            refers to the center of the voxel.
         """
         super(LazyTractogram, self).__init__(streamlines,
                                              data_per_streamline,
-                                             data_per_point)
+                                             data_per_point,
+                                             affine_to_rasmm)

         self._nb_streamlines = None
         self._data = None
         self._affine_to_apply = np.eye(4)

From ebe702ca9b3991370405afb0670e37cbf014ee8e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Mon, 29 Feb 2016 10:45:41 -0500
Subject: [PATCH 089/135] Removed ref param in nibabel.streamlines.load

---
 nibabel/streamlines/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py
index 51696979e7..b4c740024e 100644
--- a/nibabel/streamlines/__init__.py
+++ b/nibabel/streamlines/__init__.py
@@ -59,7 +59,7 @@ def detect_format(fileobj):
     return None


-def load(fileobj, lazy_load=False, ref=None):
+def load(fileobj, lazy_load=False):
     """ Loads streamlines in *RAS+* and *mm* space from a file-like object.

     Parameters

From a91cb475d7b962edd5201fe04a9c85ddf9f0896b Mon Sep 17 00:00:00 2001
From: Matthew Brett
Date: Wed, 2 Mar 2016 21:32:56 -0800
Subject: [PATCH 090/135] RF: refactor trk file object str

Refactor using big format string and compiled dictionary of variables.

---
 nibabel/streamlines/trk.py | 63 +++++++++++++++++++-------------------
 1 file changed, 32 insertions(+), 31 deletions(-)

diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index bd501c9114..2bb619ef4e 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -7,6 +7,7 @@
 import struct
 import warnings
 import itertools
+import string

 import numpy as np
 import nibabel as nib
@@ -781,34 +782,34 @@ def __str__(self):
         info : string
             Header information relevant to the TRK format.
""" - hdr = self.header - - info = "" - info += "\nMAGIC NUMBER: {0}".format(hdr[Field.MAGIC_NUMBER]) - info += "\nv.{0}".format(hdr['version']) - info += "\ndim: {0}".format(hdr[Field.DIMENSIONS]) - info += "\nvoxel_sizes: {0}".format(hdr[Field.VOXEL_SIZES]) - info += "\norgin: {0}".format(hdr[Field.ORIGIN]) - info += "\nnb_scalars: {0}".format(hdr[Field.NB_SCALARS_PER_POINT]) - info += "\nscalar_name:\n {0}".format( - "\n".join(map(asstr, hdr['scalar_name']))) - info += "\nnb_properties: {0}".format( - hdr[Field.NB_PROPERTIES_PER_STREAMLINE]) - info += "\nproperty_name:\n {0}".format( - "\n".join(map(asstr, hdr['property_name']))) - info += "\nvox_to_world: {0}".format(hdr[Field.VOXEL_TO_RASMM]) - info += "\nvoxel_order: {0}".format(hdr[Field.VOXEL_ORDER]) - info += "\nimage_orientation_patient: {0}".format( - hdr['image_orientation_patient']) - info += "\npad1: {0}".format(hdr['pad1']) - info += "\npad2: {0}".format(hdr['pad2']) - info += "\ninvert_x: {0}".format(hdr['invert_x']) - info += "\ninvert_y: {0}".format(hdr['invert_y']) - info += "\ninvert_z: {0}".format(hdr['invert_z']) - info += "\nswap_xy: {0}".format(hdr['swap_xy']) - info += "\nswap_yz: {0}".format(hdr['swap_yz']) - info += "\nswap_zx: {0}".format(hdr['swap_zx']) - info += "\nn_count: {0}".format(hdr[Field.NB_STREAMLINES]) - info += "\nhdr_size: {0}".format(hdr['hdr_size']) - - return info + vars = self.header.copy() + for attr in dir(Field): + if attr[0] in string.ascii_uppercase: + hdr_field = getattr(Field, attr) + if hdr_field in vars: + vars[attr] = vars[hdr_field] + vars['scalar_names'] = '\n'.join(map(asstr, vars['scalar_name'])) + vars['property_names'] = "\n".join(map(asstr, vars['property_name'])) + return """\ +MAGIC NUMBER: {MAGIC_NUMBER} +v.{version} +dim: {DIMENSIONS} +voxel_sizes: {VOXEL_SIZES} +orgin: {ORIGIN} +nb_scalars: {NB_SCALARS_PER_POINT} +scalar_name:\n {scalar_names} +nb_properties: {NB_PROPERTIES_PER_STREAMLINE} +property_name:\n {property_names} +vox_to_world: {VOXEL_TO_RASMM} +voxel_order: {VOXEL_ORDER} +image_orientation_patient: {image_orientation_patient} +pad1: {pad1} +pad2: {pad2} +invert_x: {invert_x} +invert_y: {invert_y} +invert_z: {invert_z} +swap_xy: {swap_xy} +swap_yz: {swap_yz} +swap_zx: {swap_zx} +n_count: {NB_STREAMLINES} +hdr_size: {hdr_size}""".format(**vars) From e331486cea235ecdc0444f510b452153f6a34017 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 30 Mar 2016 13:49:30 -0400 Subject: [PATCH 091/135] Rebased and addressed some comments --- nibabel/streamlines/tractogram.py | 18 ++++++++++-------- nibabel/streamlines/tractogram_file.py | 12 ++++++------ nibabel/streamlines/trk.py | 4 +--- 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 70e6af7323..9e8ffa9fe1 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -245,10 +245,11 @@ def __init__(self, streamlines=None, of ndarrays of shape ($N_t$, $M_i$) where $N_t$ is the number of points for a particular streamline $t$ and $M_i$ is the number scalar values to store for that particular information $i$. - affine_to_rasmm : ndarray of shape (4, 4) + affine_to_rasmm : ndarray of shape (4, 4), optional Transformation matrix that brings the streamlines contained in this tractogram to *RAS+* and *mm* space where coordinate (0,0,0) - refers to the center of the voxel. + refers to the center of the voxel. 
By default, the streamlines + are assumed to be already in *RAS+* and *mm* space. """ self.streamlines = streamlines self.data_per_streamline = data_per_streamline @@ -608,12 +609,13 @@ def __getitem__(self, idx): raise NotImplementedError('`LazyTractogram` does not support indexing.') def __iter__(self): - i = 0 - for i, tractogram_item in enumerate(self.data, start=1): + count = 0 + for tractogram_item in self.data: yield tractogram_item + count += 1 # Keep how many streamlines there are in this tractogram. - self._nb_streamlines = i + self._nb_streamlines = count def __len__(self): # Check if we know how many streamlines there are. @@ -642,13 +644,13 @@ def copy(self): def apply_affine(self, affine): """ Applies an affine transformation to the streamlines. - The transformation will be applied just before returning the - streamlines. + The transformation given by the `affine` matrix is applied after any + other pending transformations to the streamline points. Parameters ---------- affine : 2D array (4,4) - Transformation that will be applied on each streamline. + Transformation matrix that will be applied on each streamline. Returns ------- diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index 9a271df4ee..4e77aba342 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -5,19 +5,19 @@ class ExtensionWarning(Warning): - pass + """ Base class for warnings about tractogram file extension. """ class HeaderWarning(Warning): - pass + """ Base class for warnings about tractogram file header. """ class HeaderError(Exception): - pass + """ Raised when a tractogram file header contains invalid information. """ class DataError(Exception): - pass + """ Raised when data is missing or inconsistent in a tractogram file. """ class abstractclassmethod(classmethod): @@ -71,12 +71,12 @@ def get_magic_number(cls): @abstractclassmethod def support_data_per_point(cls): - """ Tells if this tractogram format supports saving data per point. """ + """ Tells if this format supports saving data per point. """ raise NotImplementedError() @abstractclassmethod def support_data_per_streamline(cls): - """ Tells if this tractogram format supports saving data per streamline. """ + """ Tells if this format supports saving data per streamline. """ raise NotImplementedError() @abstractclassmethod diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 2bb619ef4e..7a788d4b03 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -450,7 +450,7 @@ class TrkFile(TractogramFile): back on save. """ - # Contants + # Constants MAGIC_NUMBER = b"TRACK" HEADER_SIZE = 1000 READ_BUFFER_SIZE = 87382 * 4 # About 4 Mb when no scalars nor properties. 
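As an aside, a minimal sketch of the magic-number probe these constants support, which `is_correct_format` performs on the first bytes of a file (the file name "bundle.trk" is a hypothetical example, not part of this patch series):

    # Sketch only: probe a file for the TRK magic number.
    from nibabel.streamlines.trk import TrkFile

    with open("bundle.trk", "rb") as f:  # "bundle.trk" is a hypothetical path.
        magic_number = f.read(5)         # The TRK magic number is 5 bytes long.
    print(magic_number == TrkFile.MAGIC_NUMBER)  # True only for TRK files.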
@@ -514,8 +514,6 @@ def is_correct_format(cls, fileobj): f.seek(-5, os.SEEK_CUR) return magic_number == cls.MAGIC_NUMBER - return False - @classmethod def _create_arraysequence_from_generator(cls, gen): """ Creates a ArraySequence object from a generator yielding tuples of From aadcf57d0a719d2b8e2584fb5bf5334c627e4f69 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 30 Mar 2016 16:50:52 -0400 Subject: [PATCH 092/135] Fixed affine transformation of LazyTractogram --- nibabel/streamlines/tests/test_tractogram.py | 33 ++++---- nibabel/streamlines/tractogram.py | 46 +++++++---- nibabel/streamlines/trk.py | 85 +++++++++++--------- 3 files changed, 94 insertions(+), 70 deletions(-) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 77d9cd67e5..3652d14354 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -571,18 +571,21 @@ def test_lazy_tractogram_apply_affine(self): tractogram = DATA['lazy_tractogram'].copy() transformed_tractogram = tractogram.apply_affine(affine) - assert_true(transformed_tractogram is tractogram) - assert_array_equal(tractogram._affine_to_apply, affine) - assert_array_equal(tractogram.get_affine_to_rasmm(), + assert_true(transformed_tractogram is not tractogram) + assert_array_equal(tractogram._affine_to_apply, np.eye(4)) + assert_array_equal(tractogram.get_affine_to_rasmm(), np.eye(4)) + assert_array_equal(transformed_tractogram._affine_to_apply, affine) + assert_array_equal(transformed_tractogram.get_affine_to_rasmm(), np.dot(np.eye(4), np.linalg.inv(affine))) - check_tractogram(tractogram, + check_tractogram(transformed_tractogram, streamlines=[s*scaling for s in DATA['streamlines']], data_per_streamline=DATA['data_per_streamline'], data_per_point=DATA['data_per_point']) # Apply affine again and check the affine_to_rasmm. - transformed_tractogram = tractogram.apply_affine(affine) - assert_array_equal(tractogram._affine_to_apply, np.dot(affine, affine)) + transformed_tractogram = transformed_tractogram.apply_affine(affine) + assert_array_equal(transformed_tractogram._affine_to_apply, + np.dot(affine, affine)) assert_array_equal(transformed_tractogram.get_affine_to_rasmm(), np.dot(np.eye(4), np.dot(np.linalg.inv(affine), np.linalg.inv(affine)))) @@ -594,21 +597,21 @@ def test_tractogram_to_world(self): # Apply the affine to the streamlines, then bring them back # to world space in a lazy manner. - tractogram.apply_affine(affine) - assert_array_equal(tractogram.get_affine_to_rasmm(), + transformed_tractogram = tractogram.apply_affine(affine) + assert_array_equal(transformed_tractogram.get_affine_to_rasmm(), np.linalg.inv(affine)) - tractogram_world = tractogram.to_world() - assert_true(tractogram_world is tractogram) - assert_array_almost_equal(tractogram.get_affine_to_rasmm(), + tractogram_world = transformed_tractogram.to_world() + assert_true(tractogram_world is not transformed_tractogram) + assert_array_almost_equal(tractogram_world.get_affine_to_rasmm(), np.eye(4)) - for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']): + for s1, s2 in zip(tractogram_world.streamlines, DATA['streamlines']): assert_array_almost_equal(s1, s2) # Calling to_world twice should do nothing. 
- tractogram.to_world() - assert_array_almost_equal(tractogram.get_affine_to_rasmm(), np.eye(4)) - for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']): + tractogram_world = tractogram_world.to_world() + assert_array_almost_equal(tractogram_world.get_affine_to_rasmm(), np.eye(4)) + for s1, s2 in zip(tractogram_world.streamlines, DATA['streamlines']): assert_array_almost_equal(s1, s2) def test_lazy_tractogram_copy(self): diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 9e8ffa9fe1..11e2a931f1 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -320,7 +320,7 @@ def apply_affine(self, affine, lazy=False): ---------- affine : ndarray of shape (4, 4) Transformation that will be applied to every streamline. - lazy_load : {False, True}, optional + lazy : {False, True}, optional If True, streamlines are *not* transformed in-place and a :class:`LazyTractogram` object is returned. Otherwise, streamlines are modified in-place. @@ -336,8 +336,7 @@ def apply_affine(self, affine, lazy=False): """ if lazy: lazy_tractogram = LazyTractogram.from_tractogram(self) - lazy_tractogram.apply_affine(affine) - return lazy_tractogram + return lazy_tractogram.apply_affine(affine) if len(self.streamlines) == 0: return self @@ -361,7 +360,7 @@ def to_world(self, lazy=False): Parameters ---------- - lazy_load : {False, True}, optional + lazy : {False, True}, optional If True, streamlines are *not* transformed in-place and a :class:`LazyTractogram` object is returned. Otherwise, streamlines are modified in-place. @@ -641,7 +640,7 @@ def copy(self): tractogram._affine_to_apply = self._affine_to_apply.copy() return tractogram - def apply_affine(self, affine): + def apply_affine(self, affine, lazy=True): """ Applies an affine transformation to the streamlines. The transformation given by the `affine` matrix is applied after any @@ -651,29 +650,46 @@ def apply_affine(self, affine): ---------- affine : 2D array (4,4) Transformation matrix that will be applied on each streamline. + lazy : True, optional + Should always be True for :class:`LazyTractogram` object. Doing + otherwise will raise a ValueError. Returns ------- lazy_tractogram : :class:`LazyTractogram` object - Reference to this instance of :class:`LazyTractogram`. + A copy of this :class:`LazyTractogram` instance but with a + transformation to be applied on the streamlines. """ + if not lazy: + msg = "LazyTractogram only supports lazy transformations." + raise ValueError(msg) + + tractogram = self.copy() # New instance. + # Update the affine that will be applied when returning streamlines. - self._affine_to_apply = np.dot(affine, self._affine_to_apply) + tractogram._affine_to_apply = np.dot(affine, self._affine_to_apply) # Update the affine that brings back the streamlines to RASmm. - self._affine_to_rasmm = np.dot(self._affine_to_rasmm, - np.linalg.inv(affine)) - return self + tractogram._affine_to_rasmm = np.dot(self._affine_to_rasmm, + np.linalg.inv(affine)) + return tractogram - def to_world(self): + def to_world(self, lazy=True): """ Brings the streamlines to world space (i.e. RAS+ and mm). - The transformation will be applied just before returning the - streamlines. + The transformation is applied after any other pending transformations + to the streamline points. + + Parameters + ---------- + lazy : True, optional + Should always be True for :class:`LazyTractogram` object. Doing + otherwise will raise a ValueError. 
Returns ------- lazy_tractogram : :class:`LazyTractogram` object - Reference to this instance of :class:`LazyTractogram`. + A copy of this :class:`LazyTractogram` instance but with a + transformation to be applied on the streamlines. """ - return self.apply_affine(self._affine_to_rasmm) + return self.apply_affine(self._affine_to_rasmm, lazy=lazy) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 7a788d4b03..34c55b9897 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -270,6 +270,43 @@ def __init__(self, fileobj, header): self.file.write(self.header.tostring()) + # `Tractogram` streamlines are in RAS+ and mm space, we will compute + # the affine matrix that will bring them back to 'voxelmm' as required + # by the TRK format. + affine = np.eye(4) + + # Applied the inverse of the affine found in the TRK header. + # rasmm -> voxel + affine = np.dot(np.linalg.inv(self.header[Field.VOXEL_TO_RASMM]), + affine) + + # If the voxel order implied by the affine does not match the voxel + # order in the TRK header, change the orientation. + # voxel (affine) -> voxel (header) + header_ornt = asstr(self.header[Field.VOXEL_ORDER]) + affine_ornt = "".join(aff2axcodes(self.header[Field.VOXEL_TO_RASMM])) + header_ornt = axcodes2ornt(header_ornt) + affine_ornt = axcodes2ornt(affine_ornt) + ornt = nib.orientations.ornt_transform(affine_ornt, header_ornt) + M = nib.orientations.inv_ornt_aff(ornt, self.header[Field.DIMENSIONS]) + affine = np.dot(M, affine) + + # TrackVis considers coordinate (0,0,0) to be the corner of the + # voxel whereas `Tractogram` streamlines assume (0,0,0) is the + # center of the voxel. Thus, streamlines are shifted of half a voxel. + offset = np.eye(4) + offset[:-1, -1] += 0.5 + affine = np.dot(offset, affine) + + # Finally send the streamlines in mm space. + # voxel -> voxelmm + scale = np.eye(4) + scale[range(3), range(3)] *= self.header[Field.VOXEL_SIZES] + affine = np.dot(scale, affine) + + # The TRK format uses float32 as the data type for points. + self._affine_rasmm_to_voxmm = affine.astype(np.float32) + def write(self, tractogram): i4_dtype = np.dtype(" voxel - affine = np.dot(np.linalg.inv(self.header[Field.VOXEL_TO_RASMM]), - affine) - - # If the voxel order implied by the affine does not match the voxel - # order in the TRK header, change the orientation. - # voxel (affine) -> voxel (header) - header_ornt = asstr(self.header[Field.VOXEL_ORDER]) - affine_ornt = "".join(aff2axcodes(self.header[Field.VOXEL_TO_RASMM])) - header_ornt = axcodes2ornt(header_ornt) - affine_ornt = axcodes2ornt(affine_ornt) - ornt = nib.orientations.ornt_transform(affine_ornt, header_ornt) - M = nib.orientations.inv_ornt_aff(ornt, self.header[Field.DIMENSIONS]) - affine = np.dot(M, affine) - - # TrackVis considers coordinate (0,0,0) to be the corner of the - # voxel whereas `Tractogram` streamlines assume (0,0,0) is the - # center of the voxel. Thus, streamlines are shifted of half a voxel. - offset = np.eye(4) - offset[:-1, -1] += 0.5 - affine = np.dot(offset, affine) - - # Finally send the streamlines in mm space. - # voxel -> voxelmm - scale = np.eye(4) - scale[range(3), range(3)] *= self.header[Field.VOXEL_SIZES] - affine = np.dot(scale, affine) - - # The TRK format uses float32 as the data type for points. - affine = affine.astype(np.float32) + # Make sure streamlines are in rasmm then send them to voxmm. 
+        tractogram = tractogram.to_world(lazy=True)
+        tractogram = tractogram.apply_affine(self._affine_rasmm_to_voxmm,
+                                             lazy=True)

         for t in tractogram:
             if any((len(d) != len(t.streamline)
                     for d in t.data_for_points.values())):
                 raise DataError("Missing scalars for some points!")

-            points = apply_affine(affine,
-                                  np.asarray(t.streamline, dtype=f4_dtype))
+            points = np.asarray(t.streamline, dtype=f4_dtype)
             scalars = [np.asarray(t.data_for_points[k], dtype=f4_dtype)
                        for k in data_for_points_keys]
             scalars = np.concatenate([np.ndarray((len(points), 0),
@@ -747,8 +751,9 @@ def _read():
             for name, slice_ in data_per_streamline_slice.items():
                 tractogram.data_per_streamline[name] = properties[:, slice_]

-            # Bring tractogram to RAS+ and mm space
-            tractogram.apply_affine(affine.astype(np.float32))
+            # Bring tractogram to RAS+ and mm space.
+            tractogram = tractogram.apply_affine(affine.astype(np.float32))
+            tractogram._affine_to_rasmm = np.eye(4)

             return cls(tractogram, header=hdr)

From 96a221ff494ebab8c16e6a23ab9121dedb474b63 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Wed, 30 Mar 2016 17:04:58 -0400
Subject: [PATCH 093/135] Fixed str method of TrkFile

---
 nibabel/streamlines/trk.py | 15 +++++++++------
 1 file changed, 9 insertions(+), 6 deletions(-)

diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index 34c55b9897..dd5be5bb15 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -12,7 +12,6 @@
 import numpy as np
 import nibabel as nib

-from nibabel.affines import apply_affine
 from nibabel.openers import Opener
 from nibabel.py3k import asbytes, asstr
 from nibabel.volumeutils import (native_code, swapped_code)
@@ -791,8 +790,12 @@ def __str__(self):
                 hdr_field = getattr(Field, attr)
                 if hdr_field in vars:
                     vars[attr] = vars[hdr_field]
-        vars['scalar_names'] = '\n'.join(map(asstr, vars['scalar_name']))
-        vars['property_names'] = "\n".join(map(asstr, vars['property_name']))
+        vars['scalar_names'] = '\n  '.join([asstr(s)
+                                            for s in vars['scalar_name']
+                                            if len(s) > 0])
+        vars['property_names'] = "\n  ".join([asstr(s)
+                                              for s in vars['property_name']
+                                              if len(s) > 0])
         return """\
 MAGIC NUMBER: {MAGIC_NUMBER}
 v.{version}
@@ -800,10 +803,10 @@ def __str__(self):
 voxel_sizes: {VOXEL_SIZES}
 origin: {ORIGIN}
 nb_scalars: {NB_SCALARS_PER_POINT}
-scalar_name:\n {scalar_names}
+scalar_name:\n  {scalar_names}
 nb_properties: {NB_PROPERTIES_PER_STREAMLINE}
-property_name:\n {property_names}
-vox_to_world: {VOXEL_TO_RASMM}
+property_name:\n  {property_names}
+vox_to_world:\n{VOXEL_TO_RASMM}
 voxel_order: {VOXEL_ORDER}
 image_orientation_patient: {image_orientation_patient}

From a6e5026267f357534556597e359ff1ddce1619f7 Mon Sep 17 00:00:00 2001
From: Matthew Brett
Date: Mon, 4 Apr 2016 11:42:12 -0700
Subject: [PATCH 094/135] DOC: small docstring edits for ArraySequence

Minor docstring edits

---
 nibabel/streamlines/array_sequence.py | 40 +++++++++++++++++++--------
 1 file changed, 29 insertions(+), 11 deletions(-)

diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py
index f7aa2ff88a..6fca60cc18 100644
--- a/nibabel/streamlines/array_sequence.py
+++ b/nibabel/streamlines/array_sequence.py
@@ -13,8 +13,8 @@ def is_array_sequence(obj):
 class ArraySequence(object):
     """ Sequence of ndarrays having variable first dimension sizes.
- This is a container allowing to store multiple ndarrays where each ndarray - might have different first dimension size but a *common* size for the + This is a container that can store multiple ndarrays where each ndarray + might have a different first dimension size but a *common* size for the remaining dimensions. More generally, an instance of :class:`ArraySequence` of length $N$ is @@ -26,7 +26,8 @@ class ArraySequence(object): BUFFER_SIZE = 87382 * 4 # About 4 Mb if item shape is 3 (e.g. 3D points). def __init__(self, iterable=None): - """ + """ Initialize array sequence instance + Parameters ---------- iterable : None or iterable or :class:`ArraySequence`, optional @@ -100,6 +101,10 @@ def append(self, element): Element to append. The shape must match already inserted elements shape except for the first dimension. + Returns + ------- + None + Notes ----- If you need to add multiple elements you should consider @@ -130,10 +135,14 @@ def extend(self, elements): If :class:`ArraySequence` object, its data are simply appended to the data of this ArraySequence. + Returns + ------- + None + Notes ----- - The shape of the elements to be added must match the one of the - data of this :class:`ArraySequence` except for the first dimension. + The shape of the elements to be added must match the one of the data of + this :class:`ArraySequence` except for the first dimension. """ if not is_array_sequence(elements): self.extend(ArraySequence(elements)) @@ -162,10 +171,19 @@ def extend(self, elements): self._offsets = np.r_[self._offsets, offsets] def copy(self): - """ Creates a copy of this :class:`ArraySequence` object. """ - # We do not simply deepcopy this object since we might have a chance - # to use less memory. For example, if the array sequence being copied - # is the result of a slicing operation on a array sequence. + """ Creates a copy of this :class:`ArraySequence` object. + + Returns + ------- + seq_copy : :class:`ArraySequence` instance + Copy of `self`. + + Notes + ----- + We do not simply deepcopy this object because we have a chance to use + less memory. For example, if the array sequence being copied is the + result of a slicing operation on an array sequence. + """ seq = ArraySequence() total_lengths = np.sum(self._lengths) seq._data = np.empty((total_lengths,) + self._data.shape[1:], @@ -185,7 +203,7 @@ def copy(self): return seq def __getitem__(self, idx): - """ Gets sequence(s) through advanced indexing. + """ Get sequence(s) through standard or advanced numpy indexing. Parameters ---------- @@ -200,7 +218,7 @@ def __getitem__(self, idx): ------- ndarray or :class:`ArraySequence` If `idx` is an int, returns the selected sequence. - Otherwise, returns a :class:`ArraySequence` object which is view + Otherwise, returns a :class:`ArraySequence` object which is a view of the selected sequences. """ if isinstance(idx, (numbers.Integral, np.integer)): From fa7a3a2b7dec2fffefcc843977f66a452bb966f9 Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Mon, 4 Apr 2016 11:42:33 -0700 Subject: [PATCH 095/135] RF: use self.__class__ for easier subclassing Instead of hard-coding the class name, use ``self.__class__``. If someone wants to subclass ArraySequence, they do not have to overwrite the methods hard-coding the class name. 
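To make the benefit concrete, a short sketch of subclassing under this change; `MyArraySequence` is a hypothetical name used only for illustration:

    import numpy as np
    from nibabel.streamlines.array_sequence import ArraySequence

    class MyArraySequence(ArraySequence):
        pass  # Hypothetical subclass, for illustration only.

    seq = MyArraySequence([np.random.rand(n, 3) for n in (2, 3, 4)])
    # copy() and slicing build their results with self.__class__, so the
    # subclass type is preserved instead of falling back to ArraySequence.
    print(type(seq.copy()).__name__)  # MyArraySequence
    print(type(seq[::2]).__name__)    # MyArraySequence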
--- nibabel/streamlines/array_sequence.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index 6fca60cc18..845d390edb 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -145,7 +145,7 @@ def extend(self, elements): this :class:`ArraySequence` except for the first dimension. """ if not is_array_sequence(elements): - self.extend(ArraySequence(elements)) + self.extend(self.__class__(elements)) return if len(elements) == 0: @@ -184,7 +184,7 @@ def copy(self): less memory. For example, if the array sequence being copied is the result of a slicing operation on an array sequence. """ - seq = ArraySequence() + seq = self.__class__() total_lengths = np.sum(self._lengths) seq._data = np.empty((total_lengths,) + self._data.shape[1:], dtype=self._data.dtype) @@ -226,7 +226,7 @@ def __getitem__(self, idx): return self._data[start:start + self._lengths[idx]] elif isinstance(idx, (slice, list)): - seq = ArraySequence() + seq = self.__class__() seq._data = self._data seq._offsets = self._offsets[idx] seq._lengths = self._lengths[idx] @@ -236,7 +236,7 @@ def __getitem__(self, idx): elif (isinstance(idx, np.ndarray) and (np.issubdtype(idx.dtype, np.integer) or np.issubdtype(idx.dtype, np.bool))): - seq = ArraySequence() + seq = self.__class__() seq._data = self._data seq._offsets = self._offsets[idx] seq._lengths = self._lengths[idx] From 7d0f80a50a74ee0c08ef7aa9a49f1a96d296c392 Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Mon, 4 Apr 2016 12:22:34 -0700 Subject: [PATCH 096/135] RF: fuse two similar test functions Two test functions have much of the same supporting code. Maybe they can be merged. --- nibabel/streamlines/tests/test_streamlines.py | 45 ++++--------------- 1 file changed, 8 insertions(+), 37 deletions(-) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index cecc5c295e..e298c1116c 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -73,46 +73,12 @@ def setup(): DATA['data_per_point']) -def test_is_supported(): - # Emtpy file/string +def test_is_supported_detect_format(): + # Test is_supported and detect_format functions + # Empty file/string f = BytesIO() assert_false(nib.streamlines.is_supported(f)) assert_false(nib.streamlines.is_supported("")) - - # Valid file without extension - for tfile_cls in nib.streamlines.FORMATS.values(): - f = BytesIO() - f.write(tfile_cls.get_magic_number()) - f.seek(0, os.SEEK_SET) - assert_true(nib.streamlines.is_supported(f)) - - # Wrong extension but right magic number - for tfile_cls in nib.streamlines.FORMATS.values(): - with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f: - f.write(tfile_cls.get_magic_number()) - f.seek(0, os.SEEK_SET) - assert_true(nib.streamlines.is_supported(f)) - - # Good extension but wrong magic number - for ext, tfile_cls in nib.streamlines.FORMATS.items(): - with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f: - f.write(b"pass") - f.seek(0, os.SEEK_SET) - assert_false(nib.streamlines.is_supported(f)) - - # Wrong extension, string only - f = "my_tractogram.asd" - assert_false(nib.streamlines.is_supported(f)) - - # Good extension, string only - for ext, tfile_cls in nib.streamlines.FORMATS.items(): - f = "my_tractogram" + ext - assert_true(nib.streamlines.is_supported(f)) - - -def test_detect_format(): - # Emtpy file/string - f = BytesIO() 
assert_true(nib.streamlines.detect_format(f) is None) assert_true(nib.streamlines.detect_format("") is None) @@ -121,6 +87,7 @@ def test_detect_format(): f = BytesIO() f.write(tfile_cls.get_magic_number()) f.seek(0, os.SEEK_SET) + assert_true(nib.streamlines.is_supported(f)) assert_true(nib.streamlines.detect_format(f) is tfile_cls) # Wrong extension but right magic number @@ -128,6 +95,7 @@ def test_detect_format(): with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f: f.write(tfile_cls.get_magic_number()) f.seek(0, os.SEEK_SET) + assert_true(nib.streamlines.is_supported(f)) assert_true(nib.streamlines.detect_format(f) is tfile_cls) # Good extension but wrong magic number @@ -135,15 +103,18 @@ def test_detect_format(): with tempfile.TemporaryFile(mode="w+b", suffix=ext) as f: f.write(b"pass") f.seek(0, os.SEEK_SET) + assert_false(nib.streamlines.is_supported(f)) assert_true(nib.streamlines.detect_format(f) is None) # Wrong extension, string only f = "my_tractogram.asd" + assert_false(nib.streamlines.is_supported(f)) assert_true(nib.streamlines.detect_format(f) is None) # Good extension, string only for ext, tfile_cls in nib.streamlines.FORMATS.items(): f = "my_tractogram" + ext + assert_true(nib.streamlines.is_supported(f)) assert_equal(nib.streamlines.detect_format(f), tfile_cls) # Extension should not be case-sensitive. From c86185b9e55f8e733f72a322fb1bcf355f7bdb40 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 11 Apr 2016 18:09:41 -0400 Subject: [PATCH 097/135] Use np.allclose to compare the affine_transformation with the identity matrix --- nibabel/streamlines/tractogram.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 11e2a931f1..9edc98fd80 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -545,7 +545,7 @@ def streamlines(self): streamlines_gen = (t.streamline for t in self._data()) # Check if we need to apply an affine. - if not np.all(self._affine_to_apply == np.eye(4)): + if not np.allclose(self._affine_to_apply, np.eye(4)): def _apply_affine(): for s in streamlines_gen: yield apply_affine(self._affine_to_apply, s) From 233af1aa7c305a96017f276ac4b29f1207b5a6bb Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Mon, 4 Apr 2016 14:03:40 -0700 Subject: [PATCH 098/135] WIP: thinking about API for data dict A suggested partial refactor of data dict. [skip ci] --- nibabel/streamlines/tractogram.py | 119 +++++++++++------------------- 1 file changed, 45 insertions(+), 74 deletions(-) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 9edc98fd80..8e722bdc29 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -10,60 +10,55 @@ def is_data_dict(obj): - """ Tells if obj is a :class:`DataDict`. """ + """ True if `obj` seems to implement the :class:`DataDict` API """ return hasattr(obj, 'store') def is_lazy_dict(obj): - """ Tells if obj is a :class:`LazyDict`. """ + """ True if `obj` seems to implement the :class:`LazyDict` API """ return is_data_dict(obj) and callable(obj.store.values()[0]) -class DataDict(collections.MutableMapping): - """ Dictionary that makes sure data are 2D array. +class SliceableDataDict(collections.MutableMapping): + """ Dictionary for which key access can do slicing on the values. - This container behaves like a standard dictionary but it makes sure its - elements are ndarrays. 
In addition, it makes sure the amount of data
-    contained in those ndarrays matches the number of streamlines of the
-    :class:`Tractogram` object provided at the instantiation of this
-    dictionary.
+    This container behaves like a standard dictionary, but it also accepts
+    an index or slice as key, in which case the indexing is applied to the
+    contained ndarray values.
     """
-    def __init__(self, tractogram, *args, **kwargs):
-        self.tractogram = tractogram
+    def __init__(self, *args, **kwargs):
         self.store = dict()
-
         # Use update to set the keys.
-        if len(args) == 1:
-            if isinstance(args[0], DataDict):
-                self.update(**args[0])
-            elif args[0] is None:
-                return
-            else:
-                self.update(dict(*args, **kwargs))
+        if len(args) != 1:
+            self.update(dict(*args, **kwargs))
+            return
+        if args[0] is None:
+            return
+        if isinstance(args[0], SliceableDataDict):
+            self.update(**args[0])
         else:
             self.update(dict(*args, **kwargs))
 
     def __getitem__(self, key):
         try:
             return self.store[key]
-        except KeyError:
-            pass  # Maybe it is an integer.
-        except TypeError:
-            pass  # Maybe it is an object for advanced indexing.
-
-        # Try to interpret key as an index/slice in which case we
-        # perform (advanced) indexing on every element of the dictionnary.
+        except (KeyError, TypeError):
+            pass  # Maybe it is an integer or a slicing object
+
+        # Try to interpret key as an index/slice for every data element, in
+        # which case we perform (maybe advanced) indexing on every element of
+        # the dictionary.
+        idx = key
+        new_dict = type(self)(None)
         try:
-            idx = key
-            new_dict = type(self)(None)
             for k, v in self.items():
                 new_dict[k] = v[idx]
-
-            return new_dict
         except TypeError:
             pass
+        else:
+            return new_dict
 
-        # That means key was not an index/slice after all.
+        # Key was not a valid index/slice after all.
         return self.store[key]  # Will raise the proper error.
 
     def __delitem__(self, key):
@@ -76,15 +71,21 @@ def __len__(self):
         return len(self.store)
 
 
-class DataPerStreamlineDict(DataDict):
-    """ Dictionary that makes sure data are 2D array.
+class PerArrayDict(SliceableDataDict):
+    """ Dictionary for which key access can do slicing on the values.
+
+    This container behaves like a standard dictionary, but it also accepts
+    an index or slice as key, in which case the indexing is applied to the
+    contained ndarray values. The elements must also be ndarrays.
 
-    This container behaves like a standard dictionary but it makes sure its
-    elements are ndarrays. In addition, it makes sure the amount of data
-    contained in those ndarrays matches the number of streamlines of the
-    :class:`Tractogram` object provided at the instantiation of this
+    In addition, it makes sure the amount of data contained in those ndarrays
+    matches the number of streamlines given at the instantiation of this
     dictionary.
     """
+    def __init__(self, n_elements, *args, **kwargs):
+        self.n_elements = n_elements
+        super(PerArrayDict, self).__init__(*args, **kwargs)
+
     def __setitem__(self, key, value):
         value = np.asarray(list(value))
 
@@ -96,42 +97,15 @@ def __setitem__(self, key, value):
             raise ValueError("data_per_streamline must be a 2D array.")
 
         # We make sure there is the right amount of values
-        # (i.e. same as the number of streamlines in the tractogram).
-        if self.tractogram is not None and len(value) != len(self.tractogram):
-            msg = ("The number of values ({0}) should match the number of"
-                   " streamlines ({1}).")
-            raise ValueError(msg.format(len(value), len(self.tractogram)))
-
-        self.store[key] = value
-
-
-class DataPerPointDict(DataDict):
-    """ Dictionary making sure data are :class:`ArraySequence` objects.
-
-    This container behaves like a standard dictionary but it makes sure its
-    elements are :class:`ArraySequence` objects. In addition, it makes sure
-    the amount of data contained in those :class:`ArraySequence` objects
-    matches the the number of points of the :class:`Tractogram` object
-    provided at the instantiation of this dictionary.
-    """
-
-    def __setitem__(self, key, value):
-        value = ArraySequence(value)
-
-        # We make sure we have the right amount of values (i.e. same as
-        # the total number of points of all streamlines in the tractogram).
-        if (self.tractogram is not None and
-                len(value._data) != len(self.tractogram.streamlines._data)):
-            msg = ("The number of values ({0}) should match the total"
-                   " number of points of all streamlines ({1}).")
-            nb_streamlines_points = self.tractogram.streamlines._data
-            raise ValueError(msg.format(len(value._data),
-                                        len(nb_streamlines_points)))
+        if self.n_elements is not None and len(value) != self.n_elements:
+            msg = ("The number of values ({0}) should match n_elements "
+                   "({1}).").format(len(value), self.n_elements)
+            raise ValueError(msg)
 
         self.store[key] = value
 
 
-class LazyDict(DataDict):
+class LazyDict(SliceableDataDict):
     """ Dictionary of generator functions.
 
     This container behaves like a dictionary but it makes sure its elements
@@ -139,15 +113,13 @@ class LazyDict(DataDict):
     When getting the element associated to a given key, the element
     (i.e. a generator function) is first called before being returned.
     """
-    def __init__(self, tractogram, *args, **kwargs):
+    def __init__(self, *args, **kwargs):
         if len(args) == 1 and isinstance(args[0], LazyDict):
             # Copy the generator functions.
-            self.tractogram = tractogram
             self.store = dict()
             self.update(**args[0].store)
             return
-
-        super(LazyDict, self).__init__(tractogram, *args, **kwargs)
+        super(LazyDict, self).__init__(*args, **kwargs)
 
     def __getitem__(self, key):
         return self.store[key]()
@@ -155,7 +127,6 @@ def __getitem__(self, key):
     def __setitem__(self, key, value):
         if value is not None and not callable(value):
             raise TypeError("`value` must be a generator function or None.")
-
         self.store[key] = value
 

From 1d01c0afba14c94028b12669ae8618cc60557424 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Tue, 12 Apr 2016 00:34:57 -0400
Subject: [PATCH 099/135] Refactored DataDict following @matthew-brett's
 advice.

---
 nibabel/streamlines/array_sequence.py        |   5 +
 nibabel/streamlines/tests/test_tractogram.py | 160 +++++++++++++------
 nibabel/streamlines/tractogram.py            | 113 +++++++++----
 3 files changed, 191 insertions(+), 87 deletions(-)

diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py
index 845d390edb..4c623cf942 100644
--- a/nibabel/streamlines/array_sequence.py
+++ b/nibabel/streamlines/array_sequence.py
@@ -92,6 +92,11 @@ def common_shape(self):
         """ Matching shape of the elements in this array sequence. """
         return self._data.shape[1:]
 
+    @property
+    def nb_elements(self):
+        """ Total number of elements in this array sequence. """
+        return self._data.shape[0]
+
     def append(self, element):
         """ Appends `element` to this array sequence.
diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 3652d14354..ece9dc9561 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -11,7 +11,7 @@ from .. import tractogram as module_tractogram from ..tractogram import TractogramItem, Tractogram, LazyTractogram -from ..tractogram import DataPerStreamlineDict, DataPerPointDict, LazyDict +from ..tractogram import PerArrayDict, PerArraySequenceDict, LazyDict DATA = {} @@ -126,53 +126,25 @@ def assert_tractogram_equal(t1, t2): t2.data_per_streamline, t2.data_per_point) -class TestTractogramItem(unittest.TestCase): +class TestPerArrayDict(unittest.TestCase): - def test_creating_tractogram_item(self): - rng = np.random.RandomState(42) - streamline = rng.rand(rng.randint(10, 50), 3) - colors = rng.rand(len(streamline), 3) - mean_curvature = 1.11 - mean_color = np.array([0, 1, 0], dtype="f4") - - data_for_streamline = {"mean_curvature": mean_curvature, - "mean_color": mean_color} - - data_for_points = {"colors": colors} - - # Create a tractogram item with a streamline, data. - t = TractogramItem(streamline, data_for_streamline, data_for_points) - assert_equal(len(t), len(streamline)) - assert_array_equal(t.streamline, streamline) - assert_array_equal(list(t), streamline) - assert_array_equal(t.data_for_streamline['mean_curvature'], - mean_curvature) - assert_array_equal(t.data_for_streamline['mean_color'], - mean_color) - assert_array_equal(t.data_for_points['colors'], - colors) - - -class TestTractogramDataDict(unittest.TestCase): - - def test_datadict_creation(self): - # Create a DataPerStreamlineDict object using another - # DataPerStreamlineDict object. + def test_per_array_dict_creation(self): + # Create a PerArrayDict object using another + # PerArrayDict object. + nb_streamlines = len(DATA['tractogram']) data_per_streamline = DATA['tractogram'].data_per_streamline - data_dict = DataPerStreamlineDict(DATA['tractogram'], - data_per_streamline) + data_dict = PerArrayDict(nb_streamlines, data_per_streamline) assert_equal(data_dict.keys(), data_per_streamline.keys()) for k in data_dict.keys(): assert_array_equal(data_dict[k], data_per_streamline[k]) del data_dict['mean_curvature'] assert_equal(len(data_dict), - len(DATA['tractogram'].data_per_streamline)-1) + len(data_per_streamline)-1) - # Create a DataPerStreamlineDict object using an existing dict object. - data_per_streamline = DATA['tractogram'].data_per_streamline.store - data_dict = DataPerStreamlineDict(DATA['tractogram'], - data_per_streamline) + # Create a PerArrayDict object using an existing dict object. + data_per_streamline = DATA['data_per_streamline'] + data_dict = PerArrayDict(nb_streamlines, data_per_streamline) assert_equal(data_dict.keys(), data_per_streamline.keys()) for k in data_dict.keys(): assert_array_equal(data_dict[k], data_per_streamline[k]) @@ -180,10 +152,9 @@ def test_datadict_creation(self): del data_dict['mean_curvature'] assert_equal(len(data_dict), len(data_per_streamline)-1) - # Create a DataPerStreamlineDict object using keyword arguments. - data_per_streamline = DATA['tractogram'].data_per_streamline.store - data_dict = DataPerStreamlineDict(DATA['tractogram'], - **data_per_streamline) + # Create a PerArrayDict object using keyword arguments. 
+ data_per_streamline = DATA['data_per_streamline'] + data_dict = PerArrayDict(nb_streamlines, **data_per_streamline) assert_equal(data_dict.keys(), data_per_streamline.keys()) for k in data_dict.keys(): assert_array_equal(data_dict[k], data_per_streamline[k]) @@ -192,21 +163,77 @@ def test_datadict_creation(self): assert_equal(len(data_dict), len(data_per_streamline)-1) def test_getitem(self): - data_dict = DataPerPointDict(DATA['tractogram'], - DATA['data_per_point']) + sdict = PerArrayDict(len(DATA['tractogram']), + DATA['data_per_streamline']) + + assert_raises(KeyError, sdict.__getitem__, 'invalid') + + # Test slicing and advanced indexing. + for k, v in DATA['tractogram'].data_per_streamline.items(): + assert_true(k in sdict) + assert_arrays_equal(sdict[k], v) + assert_arrays_equal(sdict[::2][k], v[::2]) + assert_arrays_equal(sdict[::-1][k], v[::-1]) + assert_arrays_equal(sdict[-1][k], v[-1]) + assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]]) + + +class TestPerArraySequenceDict(unittest.TestCase): + + def test_per_array_sequence_dict_creation(self): + # Create a PerArraySequenceDict object using another + # PerArraySequenceDict object. + nb_elements = DATA['tractogram'].streamlines.nb_elements + data_per_point = DATA['tractogram'].data_per_point + data_dict = PerArraySequenceDict(nb_elements, data_per_point) + assert_equal(data_dict.keys(), data_per_point.keys()) + for k in data_dict.keys(): + assert_arrays_equal(data_dict[k], data_per_point[k]) + + del data_dict['fa'] + assert_equal(len(data_dict), + len(data_per_point)-1) + + # Create a PerArraySequenceDict object using an existing dict object. + data_per_point = DATA['data_per_point'] + data_dict = PerArraySequenceDict(nb_elements, data_per_point) + assert_equal(data_dict.keys(), data_per_point.keys()) + for k in data_dict.keys(): + assert_arrays_equal(data_dict[k], data_per_point[k]) + + del data_dict['fa'] + assert_equal(len(data_dict), len(data_per_point)-1) + + # Create a PerArraySequenceDict object using keyword arguments. + data_per_point = DATA['data_per_point'] + data_dict = PerArraySequenceDict(nb_elements, **data_per_point) + assert_equal(data_dict.keys(), data_per_point.keys()) + for k in data_dict.keys(): + assert_arrays_equal(data_dict[k], data_per_point[k]) + + del data_dict['fa'] + assert_equal(len(data_dict), len(data_per_point)-1) + + def test_getitem(self): + nb_elements = DATA['tractogram'].streamlines.nb_elements + sdict = PerArraySequenceDict(nb_elements, DATA['data_per_point']) - assert_true('fa' in data_dict) - assert_arrays_equal(data_dict['fa'], DATA['fa']) - assert_arrays_equal(data_dict[::2]['fa'], DATA['fa'][::2]) - assert_arrays_equal(data_dict[::-1]['fa'], DATA['fa'][::-1]) - assert_arrays_equal(data_dict[-1]['fa'], DATA['fa'][-1]) - assert_raises(KeyError, data_dict.__getitem__, 'invalid') + assert_raises(KeyError, sdict.__getitem__, 'invalid') + # Test slicing and advanced indexing. 
+ for k, v in DATA['tractogram'].data_per_point.items(): + assert_true(k in sdict) + assert_arrays_equal(sdict[k], v) + assert_arrays_equal(sdict[::2][k], v[::2]) + assert_arrays_equal(sdict[::-1][k], v[::-1]) + assert_arrays_equal(sdict[-1][k], v[-1]) + assert_arrays_equal(sdict[[0, -1]][k], v[[0, -1]]) -class TestTractogramLazyDict(unittest.TestCase): + +class TestLazyDict(unittest.TestCase): def test_lazydict_creation(self): - data_dict = LazyDict(None, DATA['data_per_streamline_func']) + data_dict = LazyDict(DATA['data_per_streamline_func']) assert_equal(data_dict.keys(), DATA['data_per_streamline_func'].keys()) for k in data_dict.keys(): assert_array_equal(list(data_dict[k]), @@ -216,6 +243,33 @@ def test_lazydict_creation(self): len(DATA['data_per_streamline_func'])) +class TestTractogramItem(unittest.TestCase): + + def test_creating_tractogram_item(self): + rng = np.random.RandomState(42) + streamline = rng.rand(rng.randint(10, 50), 3) + colors = rng.rand(len(streamline), 3) + mean_curvature = 1.11 + mean_color = np.array([0, 1, 0], dtype="f4") + + data_for_streamline = {"mean_curvature": mean_curvature, + "mean_color": mean_color} + + data_for_points = {"colors": colors} + + # Create a tractogram item with a streamline, data. + t = TractogramItem(streamline, data_for_streamline, data_for_points) + assert_equal(len(t), len(streamline)) + assert_array_equal(t.streamline, streamline) + assert_array_equal(list(t), streamline) + assert_array_equal(t.data_for_streamline['mean_curvature'], + mean_curvature) + assert_array_equal(t.data_for_streamline['mean_color'], + mean_color) + assert_array_equal(t.data_for_points['colors'], + colors) + + class TestTractogram(unittest.TestCase): def test_tractogram_creation(self): diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 8e722bdc29..935e3ca651 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -28,16 +28,16 @@ class SliceableDataDict(collections.MutableMapping): """ def __init__(self, *args, **kwargs): self.store = dict() - # Use update to set the keys. - if len(args) != 1: - self.update(dict(*args, **kwargs)) - return - if args[0] is None: - return - if isinstance(args[0], SliceableDataDict): - self.update(**args[0]) - else: - self.update(dict(*args, **kwargs)) + # Use the 'update' method to set the keys. + if len(args) == 1: + if args[0] is None: + return + + if isinstance(args[0], SliceableDataDict): + self.update(**args[0]) + return + + self.update(dict(*args, **kwargs)) def __getitem__(self, key): try: @@ -53,7 +53,7 @@ def __getitem__(self, key): try: for k, v in self.items(): new_dict[k] = v[idx] - except TypeError: + except (TypeError, ValueError): pass else: return new_dict @@ -82,8 +82,8 @@ class PerArrayDict(SliceableDataDict): matches the number of streamlines given at the instantiation of this dictionary. 
""" - def __init__(self, n_elements, *args, **kwargs): - self.n_elements = n_elements + def __init__(self, nb_elements, *args, **kwargs): + self.nb_elements = nb_elements super(PerArrayDict, self).__init__(*args, **kwargs) def __setitem__(self, key, value): @@ -97,15 +97,43 @@ def __setitem__(self, key, value): raise ValueError("data_per_streamline must be a 2D array.") # We make sure there is the right amount of values - if self.n_elements is not None and len(value) != self.n_elements: + if self.nb_elements is not None and len(value) != self.nb_elements: msg = ("The number of values ({0}) should match n_elements " - "({1}).").format(len(value), self.n_elements) + "({1}).").format(len(value), self.nb_elements) raise ValueError(msg) self.store[key] = value -class LazyDict(SliceableDataDict): +class PerArraySequenceDict(SliceableDataDict): + """ Dictionary for which key access can do slicing on the values. + + This container behaves like a standard dictionary but extends key access to + allow keys for key access to be indices slicing into the contained ndarray + values. The elements must also be :class:`ArraySequence`. + + In addition, it makes sure the amount of data contained in those array + sequences matches the number of elements given at the instantiation + of this dictionary. + """ + def __init__(self, nb_elements, *args, **kwargs): + self.nb_elements = nb_elements + super(PerArraySequenceDict, self).__init__(*args, **kwargs) + + def __setitem__(self, key, value): + value = ArraySequence(value) + + # We make sure there is the right amount of data. + if (self.nb_elements is not None and + value.nb_elements != self.nb_elements): + msg = ("The number of values ({0}) should match " + "({1}).").format(value.nb_elements, self.nb_elements) + raise ValueError(msg) + + self.store[key] = value + + +class LazyDict(collections.MutableMapping): """ Dictionary of generator functions. This container behaves like an dictionary but it makes sure its elements @@ -114,21 +142,38 @@ class LazyDict(SliceableDataDict): generator function) is first called before being returned. """ def __init__(self, *args, **kwargs): - if len(args) == 1 and isinstance(args[0], LazyDict): - # Copy the generator functions. - self.store = dict() - self.update(**args[0].store) - return - super(LazyDict, self).__init__(*args, **kwargs) + self.store = dict() + # Use the 'update' method to set the keys. + if len(args) == 1: + if args[0] is None: + return + + if isinstance(args[0], LazyDict): + self.update(**args[0].store) # Copy the generator functions. + return + + if isinstance(args[0], SliceableDataDict): + self.update(**args[0]) + + self.update(dict(*args, **kwargs)) def __getitem__(self, key): return self.store[key]() def __setitem__(self, key, value): - if value is not None and not callable(value): + if value is not None and not callable(value): # TODO: why None? raise TypeError("`value` must be a generator function or None.") self.store[key] = value + def __delitem__(self, key): + del self.store[key] + + def __iter__(self): + return iter(self.store) + + def __len__(self): + return len(self.store) + class TractogramItem(object): """ Class containing information about one streamline. @@ -222,7 +267,7 @@ def __init__(self, streamlines=None, refers to the center of the voxel. By default, the streamlines are assumed to be already in *RAS+* and *mm* space. 
""" - self.streamlines = streamlines + self._set_streamlines(streamlines) self.data_per_streamline = data_per_streamline self.data_per_point = data_per_point self._affine_to_rasmm = affine_to_rasmm @@ -231,8 +276,7 @@ def __init__(self, streamlines=None, def streamlines(self): return self._streamlines - @streamlines.setter - def streamlines(self, value): + def _set_streamlines(self, value): self._streamlines = ArraySequence(value) @property @@ -241,7 +285,7 @@ def data_per_streamline(self): @data_per_streamline.setter def data_per_streamline(self, value): - self._data_per_streamline = DataPerStreamlineDict(self, value) + self._data_per_streamline = PerArrayDict(len(self.streamlines), value) @property def data_per_point(self): @@ -249,7 +293,8 @@ def data_per_point(self): @data_per_point.setter def data_per_point(self, value): - self._data_per_point = DataPerPointDict(self, value) + self._data_per_point = PerArraySequenceDict( + self.streamlines.nb_elements, value) def get_affine_to_rasmm(self): """ Returns the affine bringing this tractogram to RAS+mm. """ @@ -488,7 +533,8 @@ def create_from(cls, data_func): # Set data_per_streamline using data_func def _gen(key): - return lambda: (t.data_for_streamline[key] for t in data_func()) + return lambda: (t.data_for_streamline[key] + for t in data_func()) data_per_streamline_keys = first_item.data_for_streamline.keys() for k in data_per_streamline_keys: @@ -525,8 +571,7 @@ def _apply_affine(): return streamlines_gen - @streamlines.setter - def streamlines(self, value): + def _set_streamlines(self, value): if value is not None and not callable(value): raise TypeError("`streamlines` must be a generator function.") @@ -538,7 +583,7 @@ def data_per_streamline(self): @data_per_streamline.setter def data_per_streamline(self, value): - self._data_per_streamline = LazyDict(self, value) + self._data_per_streamline = LazyDict(value) @property def data_per_point(self): @@ -546,7 +591,7 @@ def data_per_point(self): @data_per_point.setter def data_per_point(self, value): - self._data_per_point = LazyDict(self, value) + self._data_per_point = LazyDict(value) @property def data(self): @@ -576,7 +621,7 @@ def _gen_data(): return _gen_data() def __getitem__(self, idx): - raise NotImplementedError('`LazyTractogram` does not support indexing.') + raise NotImplementedError('LazyTractogram does not support indexing.') def __iter__(self): count = 0 From 713ef7b64a3f1fb1b84dcdb03f070bdb6272940b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 12 Apr 2016 00:35:33 -0400 Subject: [PATCH 100/135] Improved repr of ArraySequence --- nibabel/streamlines/array_sequence.py | 12 +++++++++++- nibabel/streamlines/tests/test_array_sequence.py | 14 ++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index 4c623cf942..83b19b3bd2 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -263,7 +263,17 @@ def __len__(self): return len(self._offsets) def __repr__(self): - return repr(list(self)) + if len(self) > np.get_printoptions()['threshold']: + # Show only the first and last edgeitems. 
+            edgeitems = np.get_printoptions()['edgeitems']
+            data = str(list(self[:edgeitems]))[:-1]
+            data += ", ..., "
+            data += str(list(self[-edgeitems:]))[1:]
+        else:
+            data = str(list(self))
+
+        return "{name}({data})".format(name=self.__class__.__name__,
+                                       data=data)
 
     def save(self, filename):
         """ Saves this :class:`ArraySequence` object to a .npz file. """
diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py
index 8639ff58f8..39d7bd34d2 100644
--- a/nibabel/streamlines/tests/test_array_sequence.py
+++ b/nibabel/streamlines/tests/test_array_sequence.py
@@ -249,6 +249,20 @@ def test_arraysequence_repr(self):
         # Test that calling repr on an ArraySequence object is not failing.
         repr(SEQ_DATA['seq'])
 
+        # Test calling repr when the number of arrays is bigger than Numpy's
+        # print option threshold.
+        nb_arrays = 50
+        seq = ArraySequence(generate_data(nb_arrays, common_shape=(1,),
+                                          rng=SEQ_DATA['rng']))
+
+        bkp_threshold = np.get_printoptions()['threshold']
+        np.set_printoptions(threshold=nb_arrays*2)
+        txt1 = repr(seq)
+        np.set_printoptions(threshold=nb_arrays//2)
+        txt2 = repr(seq)
+        assert_true(len(txt2) < len(txt1))
+        np.set_printoptions(threshold=bkp_threshold)
+
     def test_save_and_load_arraysequence(self):
         # Test saving and loading an empty ArraySequence.
         with tempfile.TemporaryFile(mode="w+b", suffix=".npz") as f:

From 196761af342625ebad124a3ccd711cde5e9b1e6c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Sun, 22 May 2016 19:24:21 -0400
Subject: [PATCH 101/135] Added function to create multiple ArraySequences
 from a generator.

Also, added an option to specify the buffer size used when creating an
ArraySequence from a generator.
---
 nibabel/streamlines/array_sequence.py         | 59 ++++++++++++++-----
 .../streamlines/tests/test_array_sequence.py  | 20 ++++---
 2 files changed, 57 insertions(+), 22 deletions(-)

diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py
index 83b19b3bd2..cece00245c 100644
--- a/nibabel/streamlines/array_sequence.py
+++ b/nibabel/streamlines/array_sequence.py
@@ -10,6 +10,12 @@ def is_array_sequence(obj):
         return False
 
 
+def is_ndarray_of_int_or_bool(obj):
+    return (isinstance(obj, np.ndarray) and
+            (np.issubdtype(obj.dtype, np.integer) or
+             np.issubdtype(obj.dtype, np.bool)))
+
+
 class ArraySequence(object):
     """ Sequence of ndarrays having variable first dimension sizes.
 
@@ -23,9 +29,7 @@ class ArraySequence(object):
     same for every ndarray.
     """
 
-    BUFFER_SIZE = 87382 * 4  # About 4 Mb if item shape is 3 (e.g. 3D points).
-
-    def __init__(self, iterable=None):
+    def __init__(self, iterable=None, buffer_size=4):
         """ Initialize array sequence instance
 
         Parameters
@@ -36,6 +40,8 @@ def __init__(self, iterable=None):
             from array-like objects yielded by the iterable.
             If :class:`ArraySequence`, create a view (no memory is allocated).
             For an actual copy use :meth:`.copy` instead.
+        buffer_size : float, optional
+            Size (in Mb) for memory allocation when `iterable` is a generator.
         """
         # Create new empty `ArraySequence` object.
         self._is_view = False
@@ -62,14 +68,23 @@ def __init__(self, iterable=None):
         for i, e in enumerate(iterable):
             e = np.asarray(e)
             if i == 0:
-                new_shape = (ArraySequence.BUFFER_SIZE,) + e.shape[1:]
+                try:
+                    n_elements = np.sum([len(iterable[i])
+                                         for i in range(len(iterable))])
+                    new_shape = (n_elements,) + e.shape[1:]
+                except TypeError:
+                    # Can't get the number of elements in iterable. So,
So, + # we use a memory buffer while building the ArraySequence. + n_rows_buffer = buffer_size*1024**2 // e.nbytes + new_shape = (n_rows_buffer,) + e.shape[1:] + self._data = np.empty(new_shape, dtype=e.dtype) end = offset + len(e) - if end >= len(self._data): + if end > len(self._data): # Resize needed, adding `len(e)` items plus some buffer. nb_points = len(self._data) - nb_points += len(e) + ArraySequence.BUFFER_SIZE + nb_points += len(e) + n_rows_buffer self._data.resize((nb_points,) + self.common_shape) offsets.append(offset) @@ -230,7 +245,7 @@ def __getitem__(self, idx): start = self._offsets[idx] return self._data[start:start + self._lengths[idx]] - elif isinstance(idx, (slice, list)): + elif isinstance(idx, (slice, list)) or is_ndarray_of_int_or_bool(idx): seq = self.__class__() seq._data = self._data seq._offsets = self._offsets[idx] @@ -238,16 +253,21 @@ def __getitem__(self, idx): seq._is_view = True return seq - elif (isinstance(idx, np.ndarray) and - (np.issubdtype(idx.dtype, np.integer) or - np.issubdtype(idx.dtype, np.bool))): + elif isinstance(idx, tuple): seq = self.__class__() - seq._data = self._data - seq._offsets = self._offsets[idx] - seq._lengths = self._lengths[idx] + seq._data = self._data.__getitem__((slice(None),) + idx[1:]) + seq._offsets = self._offsets[idx[0]] + seq._lengths = self._lengths[idx[0]] seq._is_view = True return seq + # for name, slice_ in data_per_point_slice.items(): + # seq = ArraySequence() + # seq._data = scalars._data[:, slice_] + # seq._offsets = scalars._offsets + # seq._lengths = scalars._lengths + # tractogram.data_per_point[name] = seq + raise TypeError("Index must be either an int, a slice, a list of int" " or a ndarray of bool! Not " + str(type(idx))) @@ -283,7 +303,7 @@ def save(self, filename): lengths=self._lengths) @classmethod - def from_filename(cls, filename): + def load(cls, filename): """ Loads a :class:`ArraySequence` object from a .npz file. """ content = np.load(filename) seq = cls() @@ -291,3 +311,14 @@ def from_filename(cls, filename): seq._offsets = content["offsets"] seq._lengths = content["lengths"] return seq + + +def create_arraysequences_from_generator(gen, n): + """ Creates :class:`ArraySequence` objects from a generator yielding tuples + """ + seqs = [ArraySequence() for _ in range(n)] + for data in gen: + for i, seq in enumerate(seqs): + seq.append(data[i]) + + return seqs diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index 39d7bd34d2..b0f3d81708 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -61,7 +61,7 @@ def check_arr_seq(seq, arrays): def check_arr_seq_view(seq_view, seq): assert_true(seq_view._is_view) assert_true(seq_view is not seq) - assert_true(seq_view._data is seq._data) + assert_true(np.may_share_memory(seq_view._data, seq._data)) assert_true(seq_view._offsets is not seq._offsets) assert_true(seq_view._lengths is not seq._lengths) @@ -77,7 +77,7 @@ def test_creating_arraysequence_from_list(self): # List of ndarrays. N = 5 - for ndim in range(0, N+1): + for ndim in range(1, N+1): common_shape = tuple([SEQ_DATA['rng'].randint(1, 10) for _ in range(ndim-1)]) data = generate_data(nb_arrays=5, common_shape=common_shape, @@ -85,10 +85,9 @@ def test_creating_arraysequence_from_list(self): check_arr_seq(ArraySequence(data), data) # Force ArraySequence constructor to use buffering. 
-        old_buffer_size = ArraySequence.BUFFER_SIZE
-        ArraySequence.BUFFER_SIZE = 1
-        check_arr_seq(ArraySequence(SEQ_DATA['data']), SEQ_DATA['data'])
-        ArraySequence.BUFFER_SIZE = old_buffer_size
+        buffer_size = 1. / 1024**2  # 1 byte
+        check_arr_seq(ArraySequence(iter(SEQ_DATA['data']), buffer_size),
+                      SEQ_DATA['data'])
 
     def test_creating_arraysequence_from_generator(self):
         gen = (e for e in SEQ_DATA['data'])
@@ -245,6 +244,11 @@ def test_arraysequence_getitem(self):
         # Test invalid indexing
         assert_raises(TypeError, SEQ_DATA['seq'].__getitem__, 'abc')
 
+        # Get specific columns.
+        seq_view = SEQ_DATA['seq'][:, 2]
+        check_arr_seq_view(seq_view, SEQ_DATA['seq'])
+        check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data']])
+
     def test_arraysequence_repr(self):
         # Test that calling repr on an ArraySequence object is not failing.
         repr(SEQ_DATA['seq'])
@@ -269,7 +273,7 @@ def test_save_and_load_arraysequence(self):
             seq = ArraySequence()
             seq.save(f)
             f.seek(0, os.SEEK_SET)
-            loaded_seq = ArraySequence.from_filename(f)
+            loaded_seq = ArraySequence.load(f)
             assert_array_equal(loaded_seq._data, seq._data)
             assert_array_equal(loaded_seq._offsets, seq._offsets)
             assert_array_equal(loaded_seq._lengths, seq._lengths)
@@ -279,7 +283,7 @@ def test_save_and_load_arraysequence(self):
             seq = SEQ_DATA['seq']
             seq.save(f)
             f.seek(0, os.SEEK_SET)
-            loaded_seq = ArraySequence.from_filename(f)
+            loaded_seq = ArraySequence.load(f)
             assert_array_equal(loaded_seq._data, seq._data)
             assert_array_equal(loaded_seq._offsets, seq._offsets)
             assert_array_equal(loaded_seq._lengths, seq._lengths)

From 3b753dca68e4b69a321c6f3ae813af940fa1257d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Sun, 22 May 2016 19:28:32 -0400
Subject: [PATCH 102/135] Refactored TRK file. Removed unnecessary
 classmethods in TractogramFile

---
 nibabel/streamlines/tests/test_streamlines.py |  22 +-
 .../streamlines/tests/test_tractogram_file.py |  39 ---
 nibabel/streamlines/tests/test_trk.py         |  12 +-
 nibabel/streamlines/tractogram_file.py        |  15 -
 nibabel/streamlines/trk.py                    | 273 +++++-------------
 5 files changed, 80 insertions(+), 281 deletions(-)

diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py
index e298c1116c..07b9fd8214 100644
--- a/nibabel/streamlines/tests/test_streamlines.py
+++ b/nibabel/streamlines/tests/test_streamlines.py
@@ -85,7 +85,7 @@ def test_is_supported_detect_format():
     # Valid file without extension
     for tfile_cls in nib.streamlines.FORMATS.values():
         f = BytesIO()
-        f.write(tfile_cls.get_magic_number())
+        f.write(tfile_cls.MAGIC_NUMBER)
         f.seek(0, os.SEEK_SET)
         assert_true(nib.streamlines.is_supported(f))
         assert_true(nib.streamlines.detect_format(f) is tfile_cls)
@@ -93,7 +93,7 @@ def test_is_supported_detect_format():
     # Wrong extension but right magic number
     for tfile_cls in nib.streamlines.FORMATS.values():
         with tempfile.TemporaryFile(mode="w+b", suffix=".txt") as f:
-            f.write(tfile_cls.get_magic_number())
+            f.write(tfile_cls.MAGIC_NUMBER)
             f.seek(0, os.SEEK_SET)
             assert_true(nib.streamlines.is_supported(f))
             assert_true(nib.streamlines.detect_format(f) is tfile_cls)
@@ -169,11 +169,12 @@ def test_load_complex_file(self):
 
                 tractogram = Tractogram(DATA['streamlines'])
 
-                if tfile.support_data_per_point():
+                if tfile.SUPPORTS_DATA_PER_POINT:
                     tractogram.data_per_point = DATA['data_per_point']
 
-                if tfile.support_data_per_streamline():
-                    tractogram.data_per_streamline = DATA['data_per_streamline']
+                if tfile.SUPPORTS_DATA_PER_STREAMLINE:
+                    data = DATA['data_per_streamline']
+ tractogram.data_per_streamline = data assert_tractogram_equal(tfile.tractogram, tractogram) @@ -236,18 +237,19 @@ def test_save_complex_file(self): # If streamlines format does not support saving data # per point or data per streamline, a warning message # should be issued. - if not (cls.support_data_per_point() and - cls.support_data_per_streamline()): + if not (cls.SUPPORTS_DATA_PER_POINT and + cls.SUPPORTS_DATA_PER_STREAMLINE): assert_equal(len(w), 1) assert_true(issubclass(w[0].category, Warning)) tractogram = Tractogram(DATA['streamlines']) - if cls.support_data_per_point(): + if cls.SUPPORTS_DATA_PER_POINT: tractogram.data_per_point = DATA['data_per_point'] - if cls.support_data_per_streamline(): - tractogram.data_per_streamline = DATA['data_per_streamline'] + if cls.SUPPORTS_DATA_PER_STREAMLINE: + data = DATA['data_per_streamline'] + tractogram.data_per_streamline = data tfile = nib.streamlines.load(filename, lazy_load=False) assert_tractogram_equal(tfile.tractogram, tractogram) diff --git a/nibabel/streamlines/tests/test_tractogram_file.py b/nibabel/streamlines/tests/test_tractogram_file.py index 10bb481e64..b2995a124a 100644 --- a/nibabel/streamlines/tests/test_tractogram_file.py +++ b/nibabel/streamlines/tests/test_tractogram_file.py @@ -8,18 +8,6 @@ def test_subclassing_tractogram_file(): # Missing 'save' method class DummyTractogramFile(TractogramFile): - @classmethod - def get_magic_number(cls): - return False - - @classmethod - def support_data_per_point(cls): - return False - - @classmethod - def support_data_per_streamline(cls): - return False - @classmethod def is_correct_format(cls, fileobj): return False @@ -32,18 +20,6 @@ def load(cls, fileobj, lazy_load=True): # Missing 'load' method class DummyTractogramFile(TractogramFile): - @classmethod - def get_magic_number(cls): - return False - - @classmethod - def support_data_per_point(cls): - return False - - @classmethod - def support_data_per_streamline(cls): - return False - @classmethod def is_correct_format(cls, fileobj): return False @@ -55,26 +31,11 @@ def save(self, fileobj): def test_tractogram_file(): - assert_raises(NotImplementedError, TractogramFile.get_magic_number) assert_raises(NotImplementedError, TractogramFile.is_correct_format, "") - assert_raises(NotImplementedError, TractogramFile.support_data_per_point) - assert_raises(NotImplementedError, TractogramFile.support_data_per_streamline) assert_raises(NotImplementedError, TractogramFile.load, "") # Testing calling the 'save' method of `TractogramFile` object. 
class DummyTractogramFile(TractogramFile): - @classmethod - def get_magic_number(cls): - return False - - @classmethod - def support_data_per_point(cls): - return False - - @classmethod - def support_data_per_streamline(cls): - return False - @classmethod def is_correct_format(cls, fileobj): return False diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index b88c1d2634..132e7a6089 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -6,7 +6,7 @@ from nibabel.externals.six import BytesIO from nibabel.testing import suppress_warnings, clear_and_catch_warnings -from nibabel.testing import assert_arrays_equal, check_iteration +from nibabel.testing import assert_arrays_equal from nose.tools import assert_equal, assert_raises, assert_true from numpy.testing import assert_array_equal @@ -431,13 +431,3 @@ def test_write_scalars_and_properties_name_too_long(self): def test_str(self): trk = TrkFile.load(self.complex_trk_filename) str(trk) # Simply test it's not failing when called. - - def test_read_buffer_size(self): - tmp = TrkFile.READ_BUFFER_SIZE - TrkFile.READ_BUFFER_SIZE = 1 - - for lazy_load in [False, True]: - trk = TrkFile.load(self.complex_trk_filename, lazy_load=lazy_load) - assert_tractogram_equal(trk.tractogram, self.complex_tractogram) - - TrkFile.READ_BUFFER_SIZE = tmp diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index 4e77aba342..b2c7ef0018 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -64,21 +64,6 @@ def get_affine(self): """ Returns vox -> rasmm affine. """ return self.affine - @abstractclassmethod - def get_magic_number(cls): - """ Returns streamlines file's magic number. """ - raise NotImplementedError() - - @abstractclassmethod - def support_data_per_point(cls): - """ Tells if this format supports saving data per point. """ - raise NotImplementedError() - - @abstractclassmethod - def support_data_per_streamline(cls): - """ Tells if this format supports saving data per streamline. """ - raise NotImplementedError() - @abstractclassmethod def is_correct_format(cls, fileobj): """ Checks if the file has the right streamlines file format. diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index dd5be5bb15..cd34fcc47b 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -6,7 +6,6 @@ import os import struct import warnings -import itertools import string import numpy as np @@ -17,7 +16,7 @@ from nibabel.volumeutils import (native_code, swapped_code) from nibabel.orientations import (aff2axcodes, axcodes2ornt) -from .array_sequence import ArraySequence +from .array_sequence import create_arraysequences_from_generator from .tractogram_file import TractogramFile from .tractogram_file import DataError, HeaderError, HeaderWarning from .tractogram import TractogramItem, Tractogram, LazyTractogram @@ -88,6 +87,47 @@ header_2_dtype = np.dtype(header_2_dtd) +def get_affine_trackvis_to_rasmm(header): + # TRK's streamlines are in 'voxelmm' space, we will compute the + # affine matrix that will bring them back to RAS+ and mm space. + affine = np.eye(4) + + # The affine matrix found in the TRK header requires the points to + # be in the voxel space. 
+    # voxelmm -> voxel
+    scale = np.eye(4)
+    scale[range(3), range(3)] /= header[Field.VOXEL_SIZES]
+    affine = np.dot(scale, affine)
+
+    # TrackVis considers coordinate (0,0,0) to be the corner of the
+    # voxel whereas streamlines returned assume (0,0,0) to be the
+    # center of the voxel. Thus, streamlines are shifted by half a voxel.
+    offset = np.eye(4)
+    offset[:-1, -1] -= 0.5
+    affine = np.dot(offset, affine)
+
+    # If the voxel order implied by the affine does not match the voxel
+    # order in the TRK header, change the orientation.
+    # voxel (header) -> voxel (affine)
+    header_ornt = asstr(header[Field.VOXEL_ORDER])
+    affine_ornt = "".join(aff2axcodes(header[Field.VOXEL_TO_RASMM]))
+    header_ornt = axcodes2ornt(header_ornt)
+    affine_ornt = axcodes2ornt(affine_ornt)
+    ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt)
+    M = nib.orientations.inv_ornt_aff(ornt, header[Field.DIMENSIONS])
+    affine = np.dot(M, affine)
+
+    # Apply the affine found in the TRK header.
+    # voxel -> rasmm
+    voxel_to_rasmm = header[Field.VOXEL_TO_RASMM]
+    affine_voxmm_to_rasmm = np.dot(voxel_to_rasmm, affine)
+    return affine_voxmm_to_rasmm.astype(np.float32)
+
+
+def get_affine_rasmm_to_trackvis(header):
+    return np.linalg.inv(get_affine_trackvis_to_rasmm(header))
+
+
 class TrkReader(object):
     """ Convenience class to encapsulate TRK file format.
 
@@ -112,14 +152,14 @@ def __init__(self, fileobj):
         self.fileobj = fileobj
 
         with Opener(self.fileobj) as f:
-            # Read header
+            # Read the header in one block.
            header_str = f.read(header_2_dtype.itemsize)
            header_rec = np.fromstring(string=header_str, dtype=header_2_dtype)
 
             # Check endianness
-            self.endianness = native_code
+            endianness = native_code
             if header_rec['hdr_size'] != TrkFile.HEADER_SIZE:
-                self.endianness = swapped_code
+                endianness = swapped_code
 
                 # Swap byte order
                 header_rec = header_rec.newbyteorder()
@@ -132,12 +172,13 @@ def __init__(self, fileobj):
                 header_rec = np.fromstring(string=header_str,
                                            dtype=header_1_dtype)
             elif header_rec['version'] == 2:
-                pass  # Nothing more to do
+                pass  # Nothing more to do.
             else:
                 raise HeaderError('NiBabel only supports versions 1 and 2.')
 
             # Convert the first record of `header_rec` into a dictionary
             self.header = dict(zip(header_rec.dtype.names, header_rec[0]))
+            self.header[Field.ENDIANNESS] = endianness
 
             # If vox_to_ras[3][3] is 0, it means the matrix is not recorded.
             if self.header[Field.VOXEL_TO_RASMM][3][3] == 0:
@@ -167,8 +208,8 @@ def __init__(self, fileobj):
             self.offset_data = f.tell()
 
     def __iter__(self):
-        i4_dtype = np.dtype(self.endianness + "i4")
-        f4_dtype = np.dtype(self.endianness + "f4")
+        i4_dtype = np.dtype(self.header[Field.ENDIANNESS] + "i4")
+        f4_dtype = np.dtype(self.header[Field.ENDIANNESS] + "f4")
 
         with Opener(self.fileobj) as f:
             start_position = f.tell()
@@ -188,9 +229,9 @@ def __iter__(self):
             if nb_streamlines == 0:
                 nb_streamlines = np.inf
 
-            i = 0
+            count = 0
             nb_pts_dtype = i4_dtype.str[:-1]
-            while i < nb_streamlines:
+            while count < nb_streamlines:
                 nb_pts_str = f.read(i4_dtype.itemsize)
 
                 # Check if we reached EOF
@@ -211,15 +252,15 @@ def __iter__(self):
 
                 # Read properties
                 properties = np.ndarray(
-                    shape=(self.header[Field.NB_PROPERTIES_PER_STREAMLINE],),
+                    shape=(nb_properties,),
                     dtype=f4_dtype,
                     buffer=f.read(properties_size))
 
                 yield points, scalars, properties
-                i += 1
+                count += 1
 
             # In case the 'count' field was not provided.
-            self.header[Field.NB_STREAMLINES] = i
+            self.header[Field.NB_STREAMLINES] = count
 
             # Set the file position where it was (in case it was already open).
f.seek(start_position, os.SEEK_CUR) @@ -269,43 +310,6 @@ def __init__(self, fileobj, header): self.file.write(self.header.tostring()) - # `Tractogram` streamlines are in RAS+ and mm space, we will compute - # the affine matrix that will bring them back to 'voxelmm' as required - # by the TRK format. - affine = np.eye(4) - - # Applied the inverse of the affine found in the TRK header. - # rasmm -> voxel - affine = np.dot(np.linalg.inv(self.header[Field.VOXEL_TO_RASMM]), - affine) - - # If the voxel order implied by the affine does not match the voxel - # order in the TRK header, change the orientation. - # voxel (affine) -> voxel (header) - header_ornt = asstr(self.header[Field.VOXEL_ORDER]) - affine_ornt = "".join(aff2axcodes(self.header[Field.VOXEL_TO_RASMM])) - header_ornt = axcodes2ornt(header_ornt) - affine_ornt = axcodes2ornt(affine_ornt) - ornt = nib.orientations.ornt_transform(affine_ornt, header_ornt) - M = nib.orientations.inv_ornt_aff(ornt, self.header[Field.DIMENSIONS]) - affine = np.dot(M, affine) - - # TrackVis considers coordinate (0,0,0) to be the corner of the - # voxel whereas `Tractogram` streamlines assume (0,0,0) is the - # center of the voxel. Thus, streamlines are shifted of half a voxel. - offset = np.eye(4) - offset[:-1, -1] += 0.5 - affine = np.dot(offset, affine) - - # Finally send the streamlines in mm space. - # voxel -> voxelmm - scale = np.eye(4) - scale[range(3), range(3)] *= self.header[Field.VOXEL_SIZES] - affine = np.dot(scale, affine) - - # The TRK format uses float32 as the data type for points. - self._affine_rasmm_to_voxmm = affine.astype(np.float32) - def write(self, tractogram): i4_dtype = np.dtype("= len(streamlines._data): - # Resize is needed (at least `len(pts)` items will be added). - streamlines._data.resize((len(streamlines._data) + len(pts) + - cls.READ_BUFFER_SIZE, - pts.shape[1])) - scalars._data.resize((len(scalars._data) + len(scals) + - cls.READ_BUFFER_SIZE, - scals.shape[1])) - - offsets.append(offset) - lengths.append(len(pts)) - streamlines._data[offset:offset + len(pts)] = pts - scalars._data[offset:offset + len(scals)] = scals - - offset += len(pts) - - if i >= len(properties): - properties.resize((len(properties) + cls.READ_BUFFER_SIZE, - props.shape[0])) - - properties[i] = props - - streamlines._offsets = np.asarray(offsets) - streamlines._lengths = np.asarray(lengths) - - # Clear unused memory. - streamlines._data.resize((offset, pts.shape[1])) - - if scals_shape[1] == 0: - # Because resizing an empty ndarray creates memory! - scalars._data = np.empty((offset, scals.shape[1])) - else: - scalars._data.resize((offset, scals.shape[1])) - - # Share offsets and lengths between streamlines and scalars. - scalars._offsets = streamlines._offsets - scalars._lengths = streamlines._lengths - - if props_shape[0] == 0: - # Because resizing an empty ndarray creates memory! - properties = np.empty((i + 1, props.shape[0])) - else: - properties.resize((i + 1, props.shape[0])) - - return streamlines, scalars, properties - @classmethod def load(cls, fileobj, lazy_load=False): """ Loads streamlines from a file-like object. @@ -640,40 +537,6 @@ def load(cls, fileobj, lazy_load=False): trk_reader = TrkReader(fileobj) hdr = trk_reader.header - # TRK's streamlines are in 'voxelmm' space, we will compute the - # affine matrix that will bring them back to RAS+ and mm space. - affine = np.eye(4) - - # The affine matrix found in the TRK header requires the points to be - # in the voxel space. 
- # voxelmm -> voxel - scale = np.eye(4) - scale[range(3), range(3)] /= hdr[Field.VOXEL_SIZES] - affine = np.dot(scale, affine) - - # TrackVis considers coordinate (0,0,0) to be the corner of the voxel - # whereas streamlines returned assume (0,0,0) to be the center of the - # voxel. Thus, streamlines are shifted of half a voxel. - offset = np.eye(4) - offset[:-1, -1] -= 0.5 - affine = np.dot(offset, affine) - - # If the voxel order implied by the affine does not match the voxel - # order in the TRK header, change the orientation. - # voxel (header) -> voxel (affine) - header_ornt = asstr(hdr[Field.VOXEL_ORDER]) - affine_ornt = "".join(aff2axcodes(hdr[Field.VOXEL_TO_RASMM])) - header_ornt = axcodes2ornt(header_ornt) - affine_ornt = axcodes2ornt(affine_ornt) - ornt = nib.orientations.ornt_transform(header_ornt, affine_ornt) - M = nib.orientations.inv_ornt_aff(ornt, - hdr[Field.DIMENSIONS]) - affine = np.dot(M, affine) - - # Applied the affine found in the TRK header. - # voxel -> rasmm - affine = np.dot(hdr[Field.VOXEL_TO_RASMM], affine) - # Find scalars and properties name data_per_point_slice = {} if hdr[Field.NB_SCALARS_PER_POINT] > 0: @@ -690,8 +553,8 @@ def load(cls, fileobj, lazy_load=False): nb_scalars = int(np.fromstring(scalar_name[-1], np.int8)) scalar_name = scalar_name.split('\x00')[0] - data_per_point_slice[scalar_name] = slice(cpt, - cpt + nb_scalars) + slice_obj = slice(cpt, cpt + nb_scalars) + data_per_point_slice[scalar_name] = slice_obj cpt += nb_scalars if cpt < hdr[Field.NB_SCALARS_PER_POINT]: @@ -736,22 +599,20 @@ def _read(): tractogram = LazyTractogram.create_from(_read) else: - arr_seqs = cls._create_arraysequence_from_generator(trk_reader) + arr_seqs = create_arraysequences_from_generator(trk_reader, n=3) streamlines, scalars, properties = arr_seqs + properties = np.asarray(properties) # Actually a 2d array. tractogram = Tractogram(streamlines) for name, slice_ in data_per_point_slice.items(): - seq = ArraySequence() - seq._data = scalars._data[:, slice_] - seq._offsets = scalars._offsets - seq._lengths = scalars._lengths - tractogram.data_per_point[name] = seq + tractogram.data_per_point[name] = scalars[:, slice_] for name, slice_ in data_per_streamline_slice.items(): tractogram.data_per_streamline[name] = properties[:, slice_] # Bring tractogram to RAS+ and mm space. - tractogram = tractogram.apply_affine(affine.astype(np.float32)) + affine_to_rasmm = get_affine_trackvis_to_rasmm(hdr) + tractogram = tractogram.apply_affine(affine_to_rasmm) tractogram._affine_to_rasmm = np.eye(4) return cls(tractogram, header=hdr) From 40ec4e57d098fb4714f8403874ab1093328e3077 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sun, 22 May 2016 23:52:38 -0400 Subject: [PATCH 103/135] Change how affine_to_rasmm is handled. 
Can now set affine_to_rasmm to None when streamlines' space is unknown --- nibabel/streamlines/array_sequence.py | 5 ++ nibabel/streamlines/header.py | 2 +- nibabel/streamlines/tests/test_streamlines.py | 25 ++++--- nibabel/streamlines/tests/test_tractogram.py | 60 +++++++++++----- nibabel/streamlines/tests/test_trk.py | 68 +++++++++++-------- nibabel/streamlines/tests/test_utils.py | 12 +--- nibabel/streamlines/tractogram.py | 63 +++++++++++------ nibabel/streamlines/trk.py | 6 +- nibabel/streamlines/utils.py | 11 +-- nibabel/testing/__init__.py | 2 + 10 files changed, 154 insertions(+), 100 deletions(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index cece00245c..35ebdfb45a 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -112,6 +112,11 @@ def nb_elements(self): """ Total number of elements in this array sequence. """ return self._data.shape[0] + @property + def data(self): + """ Elements in this array sequence. """ + return self._data + def append(self, element): """ Appends `element` to this array sequence. diff --git a/nibabel/streamlines/header.py b/nibabel/streamlines/header.py index 668d95ec78..c654b1234f 100644 --- a/nibabel/streamlines/header.py +++ b/nibabel/streamlines/header.py @@ -16,4 +16,4 @@ class Field: ORIGIN = "origin" VOXEL_TO_RASMM = "voxel_to_rasmm" VOXEL_ORDER = "voxel_order" - ENDIAN = "endian" + ENDIANNESS = "endianness" diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index 07b9fd8214..e814b3b620 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -66,11 +66,13 @@ def setup(): 'mean_torsion': mean_torsion, 'mean_colors': mean_colors} - DATA['empty_tractogram'] = Tractogram() - DATA['simple_tractogram'] = Tractogram(DATA['streamlines']) + DATA['empty_tractogram'] = Tractogram(affine_to_rasmm=np.eye(4)) + DATA['simple_tractogram'] = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) DATA['complex_tractogram'] = Tractogram(DATA['streamlines'], DATA['data_per_streamline'], - DATA['data_per_point']) + DATA['data_per_point'], + affine_to_rasmm=np.eye(4)) def test_is_supported_detect_format(): @@ -167,7 +169,8 @@ def test_load_complex_file(self): else: assert_true(type(tfile.tractogram), LazyTractogram) - tractogram = Tractogram(DATA['streamlines']) + tractogram = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) if tfile.SUPPORTS_DATA_PER_POINT: tractogram.data_per_point = DATA['data_per_point'] @@ -180,7 +183,8 @@ def test_load_complex_file(self): tractogram) def test_save_tractogram_file(self): - tractogram = Tractogram(DATA['streamlines']) + tractogram = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) trk_file = trk.TrkFile(tractogram) # No need for keyword arguments. 
@@ -204,7 +208,7 @@ def test_save_tractogram_file(self): assert_tractogram_equal(tfile.tractogram, tractogram) def test_save_empty_file(self): - tractogram = Tractogram() + tractogram = Tractogram(affine_to_rasmm=np.eye(4)) for ext, cls in nib.streamlines.FORMATS.items(): with InTemporaryDirectory(): filename = 'streamlines' + ext @@ -213,7 +217,8 @@ def test_save_empty_file(self): assert_tractogram_equal(tfile.tractogram, tractogram) def test_save_simple_file(self): - tractogram = Tractogram(DATA['streamlines']) + tractogram = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) for ext, cls in nib.streamlines.FORMATS.items(): with InTemporaryDirectory(): filename = 'streamlines' + ext @@ -224,7 +229,8 @@ def test_save_simple_file(self): def test_save_complex_file(self): complex_tractogram = Tractogram(DATA['streamlines'], DATA['data_per_streamline'], - DATA['data_per_point']) + DATA['data_per_point'], + affine_to_rasmm=np.eye(4)) for ext, cls in nib.streamlines.FORMATS.items(): with InTemporaryDirectory(): @@ -242,7 +248,8 @@ def test_save_complex_file(self): assert_equal(len(w), 1) assert_true(issubclass(w[0].category, Warning)) - tractogram = Tractogram(DATA['streamlines']) + tractogram = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) if cls.SUPPORTS_DATA_PER_POINT: tractogram.data_per_point = DATA['data_per_point'] diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index ece9dc9561..40781f70a2 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -57,7 +57,8 @@ def setup(): DATA['simple_tractogram'] = Tractogram(DATA['streamlines']) DATA['tractogram'] = Tractogram(DATA['streamlines'], DATA['data_per_streamline'], - DATA['data_per_point']) + DATA['data_per_point'], + affine_to_rasmm=np.eye(4)) DATA['streamlines_func'] = lambda: (e for e in DATA['streamlines']) fa_func = lambda: (e for e in DATA['fa']) @@ -74,7 +75,8 @@ def setup(): DATA['lazy_tractogram'] = LazyTractogram(DATA['streamlines_func'], DATA['data_per_streamline_func'], - DATA['data_per_point_func']) + DATA['data_per_point_func'], + affine_to_rasmm=np.eye(4)) def check_tractogram_item(tractogram_item, @@ -276,6 +278,7 @@ def test_tractogram_creation(self): # Create an empty tractogram. tractogram = Tractogram() check_tractogram(tractogram) + assert_true(tractogram.affine_to_rasmm is None) # Create a tractogram with only streamlines tractogram = Tractogram(streamlines=DATA['streamlines']) @@ -284,7 +287,7 @@ def test_tractogram_creation(self): # Create a tractogram with a given affine_to_rasmm. affine = np.diag([1, 2, 3, 1]) tractogram = Tractogram(affine_to_rasmm=affine) - assert_array_equal(tractogram.get_affine_to_rasmm(), affine) + assert_array_equal(tractogram.affine_to_rasmm, affine) # Create a tractogram with streamlines and other data. tractogram = Tractogram(DATA['streamlines'], @@ -447,9 +450,10 @@ def test_tractogram_apply_affine(self): streamlines=[s*scaling for s in DATA['streamlines']], data_per_streamline=DATA['data_per_streamline'], data_per_point=DATA['data_per_point']) - assert_array_equal(transformed_tractogram.get_affine_to_rasmm(), + assert_array_equal(transformed_tractogram.affine_to_rasmm, np.dot(np.eye(4), np.linalg.inv(affine))) - # Make sure streamlines of the original tractogram have not been modified. + # Make sure streamlines of the original tractogram have not been + # modified. 
 assert_arrays_equal(tractogram.streamlines, DATA['streamlines'])
 
         # Apply the affine to the streamlines in-place.
@@ -462,7 +466,7 @@ def test_tractogram_apply_affine(self):
 
         # Apply affine again and check the affine_to_rasmm.
         transformed_tractogram = tractogram.apply_affine(affine)
-        assert_array_equal(transformed_tractogram.get_affine_to_rasmm(),
+        assert_array_equal(transformed_tractogram.affine_to_rasmm,
                            np.dot(np.eye(4),
                                   np.dot(np.linalg.inv(affine),
                                          np.linalg.inv(affine))))
@@ -474,10 +478,16 @@ def test_tractogram_apply_affine(self):
         tractogram.apply_affine(affine)
         tractogram.apply_affine(np.linalg.inv(affine))
-        assert_array_almost_equal(tractogram.get_affine_to_rasmm(), np.eye(4))
+        assert_array_almost_equal(tractogram.affine_to_rasmm, np.eye(4))
         for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']):
             assert_array_almost_equal(s1, s2)
 
+        # Test removing affine_to_rasmm
+        tractogram = DATA['tractogram'].copy()
+        tractogram.affine_to_rasmm = None
+        tractogram.apply_affine(affine)
+        assert_true(tractogram.affine_to_rasmm is None)
+
     def test_tractogram_to_world(self):
         tractogram = DATA['tractogram'].copy()
         affine = np.random.RandomState(1234).randn(4, 4)
@@ -486,12 +496,12 @@ def test_tractogram_to_world(self):
         # Apply the affine to the streamlines, then bring them back
         # to world space in a lazy manner.
         transformed_tractogram = tractogram.apply_affine(affine)
-        assert_array_equal(transformed_tractogram.get_affine_to_rasmm(),
+        assert_array_equal(transformed_tractogram.affine_to_rasmm,
                            np.linalg.inv(affine))
 
         tractogram_world = transformed_tractogram.to_world(lazy=True)
         assert_true(type(tractogram_world) is LazyTractogram)
-        assert_array_almost_equal(tractogram_world.get_affine_to_rasmm(),
+        assert_array_almost_equal(tractogram_world.affine_to_rasmm,
                                   np.eye(4))
         for s1, s2 in zip(tractogram_world.streamlines, DATA['streamlines']):
             assert_array_almost_equal(s1, s2)
@@ -499,17 +509,22 @@ def test_tractogram_to_world(self):
         # Bring the streamlines back to world space in an in-place manner.
         tractogram_world = transformed_tractogram.to_world()
         assert_true(tractogram_world is tractogram)
-        assert_array_almost_equal(tractogram.get_affine_to_rasmm(), np.eye(4))
+        assert_array_almost_equal(tractogram.affine_to_rasmm, np.eye(4))
         for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']):
             assert_array_almost_equal(s1, s2)
 
         # Calling to_world twice should do nothing.
         tractogram_world2 = transformed_tractogram.to_world()
         assert_true(tractogram_world2 is tractogram)
-        assert_array_almost_equal(tractogram.get_affine_to_rasmm(), np.eye(4))
+        assert_array_almost_equal(tractogram.affine_to_rasmm, np.eye(4))
         for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']):
             assert_array_almost_equal(s1, s2)
 
+        # Calling to_world when affine_to_rasmm is None should fail.
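Before the test body that follows, a quick sketch of the behaviour it exercises: to_world() now refuses to guess when the space is unknown. Hypothetical data, not part of the test suite:

    import numpy as np
    from nibabel.streamlines.tractogram import Tractogram

    tractogram = Tractogram([np.zeros((2, 3), dtype='f4')])  # unknown space
    try:
        tractogram.to_world()
    except ValueError:
        pass  # "Streamlines are in an unknown space. ..."

    # Declaring the space afterwards makes to_world() valid again.
    tractogram.affine_to_rasmm = np.eye(4)
    tractogram.to_world()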
+ tractogram = DATA['tractogram'].copy() + tractogram.affine_to_rasmm = None + assert_raises(ValueError, tractogram.to_world) + class TestLazyTractogram(unittest.TestCase): @@ -534,6 +549,7 @@ def test_lazy_tractogram_creation(self): # Empty `LazyTractogram` tractogram = LazyTractogram() check_tractogram(tractogram) + assert_true(tractogram.affine_to_rasmm is None) # Create tractogram with streamlines and other data tractogram = LazyTractogram(DATA['streamlines_func'], @@ -627,9 +643,9 @@ def test_lazy_tractogram_apply_affine(self): transformed_tractogram = tractogram.apply_affine(affine) assert_true(transformed_tractogram is not tractogram) assert_array_equal(tractogram._affine_to_apply, np.eye(4)) - assert_array_equal(tractogram.get_affine_to_rasmm(), np.eye(4)) + assert_array_equal(tractogram.affine_to_rasmm, np.eye(4)) assert_array_equal(transformed_tractogram._affine_to_apply, affine) - assert_array_equal(transformed_tractogram.get_affine_to_rasmm(), + assert_array_equal(transformed_tractogram.affine_to_rasmm, np.dot(np.eye(4), np.linalg.inv(affine))) check_tractogram(transformed_tractogram, streamlines=[s*scaling for s in DATA['streamlines']], @@ -640,10 +656,15 @@ def test_lazy_tractogram_apply_affine(self): transformed_tractogram = transformed_tractogram.apply_affine(affine) assert_array_equal(transformed_tractogram._affine_to_apply, np.dot(affine, affine)) - assert_array_equal(transformed_tractogram.get_affine_to_rasmm(), + assert_array_equal(transformed_tractogram.affine_to_rasmm, np.dot(np.eye(4), np.dot(np.linalg.inv(affine), np.linalg.inv(affine)))) + # Calling to_world when affine_to_rasmm is None should fail. + tractogram = DATA['lazy_tractogram'].copy() + tractogram.affine_to_rasmm = None + assert_raises(ValueError, tractogram.to_world) + def test_tractogram_to_world(self): tractogram = DATA['lazy_tractogram'].copy() affine = np.random.RandomState(1234).randn(4, 4) @@ -652,22 +673,27 @@ def test_tractogram_to_world(self): # Apply the affine to the streamlines, then bring them back # to world space in a lazy manner. transformed_tractogram = tractogram.apply_affine(affine) - assert_array_equal(transformed_tractogram.get_affine_to_rasmm(), + assert_array_equal(transformed_tractogram.affine_to_rasmm, np.linalg.inv(affine)) tractogram_world = transformed_tractogram.to_world() assert_true(tractogram_world is not transformed_tractogram) - assert_array_almost_equal(tractogram_world.get_affine_to_rasmm(), + assert_array_almost_equal(tractogram_world.affine_to_rasmm, np.eye(4)) for s1, s2 in zip(tractogram_world.streamlines, DATA['streamlines']): assert_array_almost_equal(s1, s2) # Calling to_world twice should do nothing. tractogram_world = tractogram_world.to_world() - assert_array_almost_equal(tractogram_world.get_affine_to_rasmm(), np.eye(4)) + assert_array_almost_equal(tractogram_world.affine_to_rasmm, np.eye(4)) for s1, s2 in zip(tractogram_world.streamlines, DATA['streamlines']): assert_array_almost_equal(s1, s2) + # Calling to_world when affine_to_rasmm is None should fail. + tractogram = DATA['lazy_tractogram'].copy() + tractogram.affine_to_rasmm = None + assert_raises(ValueError, tractogram.to_world) + def test_lazy_tractogram_copy(self): # Create a copy of the lazy tractogram. 
tractogram = DATA['lazy_tractogram'].copy() diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 132e7a6089..7425947aad 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -2,6 +2,7 @@ import copy import unittest import numpy as np +from os.path import join as pjoin from nibabel.externals.six import BytesIO @@ -18,7 +19,7 @@ from ..trk import TrkFile from ..header import Field -DATA_PATH = os.path.join(os.path.dirname(__file__), 'data') +DATA_PATH = pjoin(os.path.dirname(__file__), 'data') def assert_header_equal(h1, h2): @@ -32,19 +33,19 @@ def assert_header_equal(h1, h2): class TestTRK(unittest.TestCase): def setUp(self): - self.empty_trk_filename = os.path.join(DATA_PATH, "empty.trk") + self.empty_trk_filename = pjoin(DATA_PATH, "empty.trk") # simple.trk contains only streamlines - self.simple_trk_filename = os.path.join(DATA_PATH, "simple.trk") + self.simple_trk_filename = pjoin(DATA_PATH, "simple.trk") # standard.trk contains only streamlines - self.standard_trk_filename = os.path.join(DATA_PATH, "standard.trk") + self.standard_trk_filename = pjoin(DATA_PATH, "standard.trk") # standard.LPS.trk contains only streamlines - self.standard_LPS_trk_filename = os.path.join(DATA_PATH, - "standard.LPS.trk") + self.standard_LPS_trk_filename = pjoin(DATA_PATH, + "standard.LPS.trk") # complex.trk contains streamlines, scalars and properties - self.complex_trk_filename = os.path.join(DATA_PATH, "complex.trk") - self.complex_trk_big_endian_filename = os.path.join(DATA_PATH, - "complex_big_endian.trk") + self.complex_trk_filename = pjoin(DATA_PATH, "complex.trk") + self.complex_trk_big_endian_filename = pjoin(DATA_PATH, + "complex_big_endian.trk") self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), np.arange(2*3, dtype="f4").reshape((2, 3)), @@ -81,11 +82,13 @@ def setUp(self): 'mean_torsion': self.mean_torsion, 'mean_colors': self.mean_colors} - self.empty_tractogram = Tractogram() - self.simple_tractogram = Tractogram(self.streamlines) + self.empty_tractogram = Tractogram(affine_to_rasmm=np.eye(4)) + self.simple_tractogram = Tractogram(self.streamlines, + affine_to_rasmm=np.eye(4)) self.complex_tractogram = Tractogram(self.streamlines, self.data_per_streamline, - self.data_per_point) + self.data_per_point, + affine_to_rasmm=np.eye(4)) def test_load_empty_file(self): for lazy_load in [False, True]: @@ -168,7 +171,7 @@ def test_tractogram_file_properties(self): assert_array_equal(trk.get_affine(), trk.header[Field.VOXEL_TO_RASMM]) def test_write_empty_file(self): - tractogram = Tractogram() + tractogram = Tractogram(affine_to_rasmm=np.eye(4)) trk_file = BytesIO() trk = TrkFile(tractogram) @@ -186,7 +189,8 @@ def test_write_empty_file(self): open(self.empty_trk_filename, 'rb').read()) def test_write_simple_file(self): - tractogram = Tractogram(self.streamlines) + tractogram = Tractogram(self.streamlines, + affine_to_rasmm=np.eye(4)) trk_file = BytesIO() trk = TrkFile(tractogram) @@ -206,7 +210,8 @@ def test_write_simple_file(self): def test_write_complex_file(self): # With scalars tractogram = Tractogram(self.streamlines, - data_per_point=self.data_per_point) + data_per_point=self.data_per_point, + affine_to_rasmm=np.eye(4)) trk_file = BytesIO() trk = TrkFile(tractogram) @@ -218,7 +223,8 @@ def test_write_complex_file(self): # With properties tractogram = Tractogram(self.streamlines, - data_per_streamline=self.data_per_streamline) + data_per_streamline=self.data_per_streamline, + 
affine_to_rasmm=np.eye(4)) trk = TrkFile(tractogram) trk_file = BytesIO() @@ -231,7 +237,8 @@ def test_write_complex_file(self): # With scalars and properties tractogram = Tractogram(self.streamlines, data_per_point=self.data_per_point, - data_per_streamline=self.data_per_streamline) + data_per_streamline=self.data_per_streamline, + affine_to_rasmm=np.eye(4)) trk_file = BytesIO() trk = TrkFile(tractogram) @@ -261,7 +268,6 @@ def test_load_write_file(self): assert_tractogram_equal(new_trk.tractogram, trk.tractogram) trk_file.seek(0, os.SEEK_SET) - #assert_equal(trk_file.read(), open(filename, 'rb').read()) def test_load_write_LPS_file(self): # Load the RAS and LPS version of the standard. @@ -287,7 +293,8 @@ def test_load_write_LPS_file(self): assert_equal(trk_file.read(), open(self.standard_LPS_trk_filename, 'rb').read()) - # Test writing a file where the header is missing the Field.VOXEL_ORDER. + # Test writing a file where the header is missing the + # Field.VOXEL_ORDER. trk_file = BytesIO() # For TRK file format, the default voxel order is LPS. @@ -313,7 +320,7 @@ def test_load_write_LPS_file(self): def test_write_optional_header_fields(self): # The TRK file format doesn't support additional header fields. # If provided, they will be ignored. - tractogram = Tractogram() + tractogram = Tractogram(affine_to_rasmm=np.eye(4)) trk_file = BytesIO() header = {'extra': 1234} @@ -331,7 +338,8 @@ def test_write_too_many_scalars_and_properties(self): data_per_point['#{0}'.format(i)] = self.fa tractogram = Tractogram(self.streamlines, - data_per_point=data_per_point) + data_per_point=data_per_point, + affine_to_rasmm=np.eye(4)) trk_file = BytesIO() trk = TrkFile(tractogram) @@ -345,7 +353,8 @@ def test_write_too_many_scalars_and_properties(self): data_per_point['#{0}'.format(i+1)] = self.fa tractogram = Tractogram(self.streamlines, - data_per_point=data_per_point) + data_per_point=data_per_point, + affine_to_rasmm=np.eye(4)) trk = TrkFile(tractogram) assert_raises(ValueError, trk.save, BytesIO()) @@ -356,7 +365,8 @@ def test_write_too_many_scalars_and_properties(self): data_per_streamline['#{0}'.format(i)] = self.mean_torsion tractogram = Tractogram(self.streamlines, - data_per_streamline=data_per_streamline) + data_per_streamline=data_per_streamline, + affine_to_rasmm=np.eye(4)) trk_file = BytesIO() trk = TrkFile(tractogram) @@ -384,7 +394,8 @@ def test_write_scalars_and_properties_name_too_long(self): for nb_chars in range(22): data_per_point = {'A'*nb_chars: self.colors} tractogram = Tractogram(self.streamlines, - data_per_point=data_per_point) + data_per_point=data_per_point, + affine_to_rasmm=np.eye(4)) trk = TrkFile(tractogram) if nb_chars > 18: @@ -394,7 +405,8 @@ def test_write_scalars_and_properties_name_too_long(self): data_per_point = {'A'*nb_chars: self.fa} tractogram = Tractogram(self.streamlines, - data_per_point=data_per_point) + data_per_point=data_per_point, + affine_to_rasmm=np.eye(4)) trk = TrkFile(tractogram) if nb_chars > 20: @@ -410,7 +422,8 @@ def test_write_scalars_and_properties_name_too_long(self): for nb_chars in range(22): data_per_streamline = {'A'*nb_chars: self.mean_colors} tractogram = Tractogram(self.streamlines, - data_per_streamline=data_per_streamline) + data_per_streamline=data_per_streamline, + affine_to_rasmm=np.eye(4)) trk = TrkFile(tractogram) if nb_chars > 18: @@ -420,7 +433,8 @@ def test_write_scalars_and_properties_name_too_long(self): data_per_streamline = {'A'*nb_chars: self.mean_torsion} tractogram = Tractogram(self.streamlines, - 
data_per_streamline=data_per_streamline) + data_per_streamline=data_per_streamline, + affine_to_rasmm=np.eye(4)) trk = TrkFile(tractogram) if nb_chars > 20: diff --git a/nibabel/streamlines/tests/test_utils.py b/nibabel/streamlines/tests/test_utils.py index 6c3bf096a6..939ee9bb9e 100644 --- a/nibabel/streamlines/tests/test_utils.py +++ b/nibabel/streamlines/tests/test_utils.py @@ -4,17 +4,9 @@ from nibabel.testing import data_path from numpy.testing import assert_array_equal -from nose.tools import assert_equal, assert_raises, assert_true +from nose.tools import assert_raises -from ..utils import pop, get_affine_from_reference - - -def test_peek(): - gen = (i for i in range(3)) - assert_equal(pop(gen), 0) - assert_equal(pop(gen), 1) - assert_equal(pop(gen), 2) - assert_true(pop(gen) is None) +from ..utils import get_affine_from_reference def test_get_affine_from_reference(): diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 935e3ca651..69eeba3493 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -240,7 +240,7 @@ class Tractogram(object): def __init__(self, streamlines=None, data_per_streamline=None, data_per_point=None, - affine_to_rasmm=np.eye(4)): + affine_to_rasmm=None): """ Parameters ---------- @@ -261,16 +261,16 @@ def __init__(self, streamlines=None, of ndarrays of shape ($N_t$, $M_i$) where $N_t$ is the number of points for a particular streamline $t$ and $M_i$ is the number scalar values to store for that particular information $i$. - affine_to_rasmm : ndarray of shape (4, 4), optional + affine_to_rasmm : ndarray of shape (4, 4) or None, optional Transformation matrix that brings the streamlines contained in this tractogram to *RAS+* and *mm* space where coordinate (0,0,0) refers to the center of the voxel. By default, the streamlines - are assumed to be already in *RAS+* and *mm* space. + are in an unknown space, i.e. affine_to_rasmm is None. """ self._set_streamlines(streamlines) self.data_per_streamline = data_per_streamline self.data_per_point = data_per_point - self._affine_to_rasmm = affine_to_rasmm + self.affine_to_rasmm = affine_to_rasmm @property def streamlines(self): @@ -296,9 +296,14 @@ def data_per_point(self, value): self._data_per_point = PerArraySequenceDict( self.streamlines.nb_elements, value) - def get_affine_to_rasmm(self): - """ Returns the affine bringing this tractogram to RAS+mm. """ - return self._affine_to_rasmm.copy() + @property + def affine_to_rasmm(self): + """ Affine bringing streamlines in this tractogram to RAS+mm. """ + return copy.deepcopy(self._affine_to_rasmm) + + @affine_to_rasmm.setter + def affine_to_rasmm(self, value): + self._affine_to_rasmm = value def __iter__(self): for i in range(len(self.streamlines)): @@ -358,14 +363,15 @@ def apply_affine(self, affine, lazy=False): return self BUFFER_SIZE = 10000000 # About 128 Mb since pts shape is 3. - for start in range(0, len(self.streamlines._data), BUFFER_SIZE): + for start in range(0, len(self.streamlines.data), BUFFER_SIZE): end = start + BUFFER_SIZE pts = self.streamlines._data[start:end] - self.streamlines._data[start:end] = apply_affine(affine, pts) + self.streamlines.data[start:end] = apply_affine(affine, pts) - # Update the affine that brings back the streamlines to RASmm. - self._affine_to_rasmm = np.dot(self._affine_to_rasmm, - np.linalg.inv(affine)) + if self.affine_to_rasmm is not None: + # Update the affine that brings back the streamlines to RASmm. 
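The assignment just below is the heart of this patch: every apply_affine(affine) folds inv(affine) into affine_to_rasmm, so the property always maps the current coordinates back to RAS+ and mm. A sketch of that invariant, again with hypothetical data:

    import numpy as np
    from nibabel.streamlines.tractogram import Tractogram

    affine = np.diag([2., 2., 2., 1.])  # scale each axis by 2
    tractogram = Tractogram([np.ones((3, 3), dtype='f4')],
                            affine_to_rasmm=np.eye(4))
    tractogram.apply_affine(affine)

    # affine_to_rasmm now equals eye(4) . inv(affine), so to_world() can
    # undo the scaling and restore the original RAS+ mm coordinates.
    assert np.allclose(tractogram.affine_to_rasmm, np.linalg.inv(affine))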
+            self.affine_to_rasmm = np.dot(self.affine_to_rasmm,
+                                          np.linalg.inv(affine))
 
         return self
 
@@ -389,7 +395,12 @@ def to_world(self, lazy=False):
             object, otherwise it returns a reference to this
             :class:`Tractogram` object with updated streamlines.
         """
-        return self.apply_affine(self._affine_to_rasmm, lazy=lazy)
+        if self.affine_to_rasmm is None:
+            msg = ("Streamlines are in an unknown space. This error can be"
+                   " avoided by setting the 'affine_to_rasmm' property.")
+            raise ValueError(msg)
+
+        return self.apply_affine(self.affine_to_rasmm, lazy=lazy)
 
 
 class LazyTractogram(Tractogram):
@@ -434,7 +445,7 @@ class LazyTractogram(Tractogram):
     def __init__(self, streamlines=None,
                  data_per_streamline=None,
                  data_per_point=None,
-                 affine_to_rasmm=np.eye(4)):
+                 affine_to_rasmm=None):
         """
         Parameters
         ----------
@@ -457,10 +468,11 @@ def __init__(self, streamlines=None,
            ($N_t$, $M_i$) where $N_t$ is the number of points for a
            particular streamline $t$ and $M_i$ is the number scalar values
            to store for that particular information $i$.
-        affine_to_rasmm : ndarray of shape (4, 4)
+        affine_to_rasmm : ndarray of shape (4, 4) or None, optional
            Transformation matrix that brings the streamlines contained in
            this tractogram to *RAS+* and *mm* space where coordinate (0,0,0)
-           refers to the center of the voxel.
+           refers to the center of the voxel. By default, the streamlines
+           are in an unknown space, i.e. affine_to_rasmm is None.
         """
         super(LazyTractogram, self).__init__(streamlines,
                                              data_per_streamline,
@@ -501,7 +513,7 @@ def _gen(key):
             lazy_tractogram._data_per_point[k] = _gen(k)
 
         lazy_tractogram._nb_streamlines = len(tractogram)
-        lazy_tractogram._affine_to_rasmm = tractogram.get_affine_to_rasmm()
+        lazy_tractogram.affine_to_rasmm = tractogram.affine_to_rasmm
         return lazy_tractogram
 
     @classmethod
@@ -650,7 +662,8 @@ def copy(self):
         """ Returns a copy of this :class:`LazyTractogram` object. """
         tractogram = LazyTractogram(self._streamlines,
                                     self._data_per_streamline,
-                                    self._data_per_point)
+                                    self._data_per_point,
+                                    self.affine_to_rasmm)
         tractogram._nb_streamlines = self._nb_streamlines
         tractogram._data = self._data
         tractogram._affine_to_apply = self._affine_to_apply.copy()
@@ -685,9 +698,10 @@ def apply_affine(self, affine, lazy=True):
         # Update the affine that will be applied when returning streamlines.
         tractogram._affine_to_apply = np.dot(affine, self._affine_to_apply)
 
-        # Update the affine that brings back the streamlines to RASmm.
-        tractogram._affine_to_rasmm = np.dot(self._affine_to_rasmm,
-                                             np.linalg.inv(affine))
+        if tractogram.affine_to_rasmm is not None:
+            # Update the affine that brings back the streamlines to RASmm.
+            tractogram.affine_to_rasmm = np.dot(self.affine_to_rasmm,
+                                                np.linalg.inv(affine))
         return tractogram
 
     def to_world(self, lazy=True):
@@ -708,4 +722,9 @@ def to_world(self, lazy=True):
             A copy of this :class:`LazyTractogram` instance but with a
             transformation to be applied on the streamlines.
         """
-        return self.apply_affine(self._affine_to_rasmm, lazy=lazy)
+        if self.affine_to_rasmm is None:
+            msg = ("Streamlines are in an unknown space. 
This error can be" + " avoided by setting the 'affine_to_rasmm' property.") + raise ValueError(msg) + + return self.apply_affine(self.affine_to_rasmm, lazy=lazy) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index cd34fcc47b..90d4f08305 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -610,10 +610,8 @@ def _read(): for name, slice_ in data_per_streamline_slice.items(): tractogram.data_per_streamline[name] = properties[:, slice_] - # Bring tractogram to RAS+ and mm space. - affine_to_rasmm = get_affine_trackvis_to_rasmm(hdr) - tractogram = tractogram.apply_affine(affine_to_rasmm) - tractogram._affine_to_rasmm = np.eye(4) + tractogram.affine_to_rasmm = get_affine_trackvis_to_rasmm(hdr) + tractogram = tractogram.to_world() return cls(tractogram, header=hdr) diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index 254e9b3442..3a5c648cc9 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -1,8 +1,5 @@ import numpy as np import nibabel -import itertools - -from nibabel.spatialimages import SpatialImage def get_affine_from_reference(ref): @@ -27,14 +24,8 @@ def get_affine_from_reference(ref): raise ValueError(msg) return ref - elif isinstance(ref, SpatialImage): + elif hasattr(ref, 'affine'): return ref.affine # Assume `ref` is the name of a neuroimaging file. return nibabel.load(ref).affine - - -def pop(iterable): - """ Returns the next item from the iterable else None. """ - value = list(itertools.islice(iterable, 1)) - return value[0] if len(value) > 0 else None diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py index fa3454d473..6868b18269 100644 --- a/nibabel/testing/__init__.py +++ b/nibabel/testing/__init__.py @@ -77,7 +77,9 @@ def check_iteration(iterable): def assert_arrays_equal(arrays1, arrays2): + """ Checks that two iterables yielding arrays are equals. """ for arr1, arr2 in zip_longest(arrays1, arrays2, fillvalue=None): + assert_false(arr1 is None or arr2 is None) assert_array_equal(arr1, arr2) From 9b08496aa3f9984a007c7fc2e2d95201316e7424 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 23 May 2016 16:16:13 -0400 Subject: [PATCH 104/135] Refactored trk.py --- nibabel/streamlines/tractogram.py | 4 +- nibabel/streamlines/trk.py | 107 ++++++++++++++++-------------- 2 files changed, 61 insertions(+), 50 deletions(-) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 69eeba3493..ff765d874b 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -223,13 +223,13 @@ class Tractogram(object): Sequence of $T$ streamlines. Each streamline is an ndarray of shape ($N_t$, 3) where $N_t$ is the number of points of streamline $t$. - data_per_streamline : dict of 2D arrays + data_per_streamline : :class:`PerArrayDict` object Dictionary where the items are (str, 2D array). Each key represents an information $i$ to be kept along side every streamline, and its associated value is a 2D array of shape ($T$, $P_i$) where $T$ is the number of streamlines and $P_i$ is the number scalar values to store for that particular information $i$. - data_per_point : dict of :class:`ArraySequence` objects + data_per_point : :class:`PerArraySequenceDict` object Dictionary where the items are (str, :class:`ArraySequence`). 
Each key represents an information $i$ to be kept along side every point of every streamline, and its associated value is an iterable diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 90d4f08305..9f34ef8d9e 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -128,6 +128,47 @@ def get_affine_rasmm_to_trackvis(header): return np.linalg.inv(get_affine_trackvis_to_rasmm(header)) +def encode_value_in_name(value, name, max_name_len=20): + """ Encodes a value in the last two bytes of a string. + + If `value` is one, then there is no encoding and the last two bytes + are left untouched. This function also verify that the length of name is + less than `max_name_len`. + + Parameters + ---------- + value : int + Integer value to encode. + name : str + Name in which the last two bytes will serve to encode `value`. + max_name_len : int, optional + Maximum length name can have. + + Returns + ------- + encoded_name : str + Name containing the encoded value. + """ + + if len(name) > max_name_len: + msg = ("Data information named '{0}' is too long" + " (max {1} characters.)").format(name, max_name_len) + raise ValueError(msg) + elif len(name) > max_name_len-2 and value > 1: + msg = ("Data information named '{0}' is too long (need to be less" + " than {1} characters when storing more than one value" + " for a given data information." + ).format(name, max_name_len-2) + raise ValueError(msg) + + if value > 1: + # Use the last two bytes of `name` to store `value`. + name = (asbytes(name[:18].ljust(18, '\x00')) + b'\x00' + + np.array(value, dtype=np.int8).tostring()) + + return name + + class TrkReader(object): """ Convenience class to encapsulate TRK file format. @@ -326,8 +367,7 @@ def write(self, tractogram): self.file.write(self.header.tostring()) return - # Update the 'property_name' field using 'data_per_streamline' of the - # tractogram. + # Update field 'property_name' using 'tractogram.data_per_streamline'. data_for_streamline = first_item.data_for_streamline if len(data_for_streamline) > MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE: msg = ("Can only store {0} named data_per_streamline (also known" @@ -336,31 +376,16 @@ def write(self, tractogram): raise ValueError(msg) data_for_streamline_keys = sorted(data_for_streamline.keys()) - self.header['property_name'] = np.zeros( - MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, - dtype='S20') - for i, k in enumerate(data_for_streamline_keys): - nb_values = data_for_streamline[k].shape[0] - - if len(k) > 20: - raise ValueError(("Property name '{0}' is too long (max 20" - "characters.)").format(k)) - elif len(k) > 18 and nb_values > 1: - raise ValueError(("Property name '{0}' is too long (need to be" - " less than 18 characters when storing more" - " than one value").format(k)) - - property_name = k - if nb_values > 1: - # Use the last to bytes of the name to store the nb of values - # associated to this data_for_streamline. - property_name = (asbytes(k[:18].ljust(18, '\x00')) + b'\x00' + - np.array(nb_values, dtype=np.int8).tostring()) - - self.header['property_name'][i] = property_name - - # Update the 'scalar_name' field using 'data_per_point' of the - # tractogram. + property_name = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, + dtype='S20') + for i, name in enumerate(data_for_streamline_keys): + # Use the last to bytes of the name to store the number of values + # associated to this data_for_streamline. 
+ nb_values = data_for_streamline[name].shape[-1] + property_name[i] = encode_value_in_name(nb_values, name) + self.header['property_name'][:] = property_name + + # Update field 'scalar_name' using 'tractogram.data_per_point'. data_for_points = first_item.data_for_points if len(data_for_points) > MAX_NB_NAMED_SCALARS_PER_POINT: raise ValueError(("Can only store {0} named data_per_point (also" @@ -368,27 +393,13 @@ def write(self, tractogram): ).format(MAX_NB_NAMED_SCALARS_PER_POINT)) data_for_points_keys = sorted(data_for_points.keys()) - self.header['scalar_name'] = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, - dtype='S20') - for i, k in enumerate(data_for_points_keys): - nb_values = data_for_points[k].shape[1] - - if len(k) > 20: - raise ValueError(("Scalar name '{0}' is too long (max 18" - " characters.)").format(k)) - elif len(k) > 18 and nb_values > 1: - raise ValueError(("Scalar name '{0}' is too long (need to be" - " less than 18 characters when storing more" - " than one value").format(k)) - - scalar_name = k - if nb_values > 1: - # Use the last to bytes of the name to store the nb of values - # associated to this data_for_streamline. - scalar_name = (asbytes(k[:18].ljust(18, '\x00')) + b'\x00' + - np.array(nb_values, dtype=np.int8).tostring()) - - self.header['scalar_name'][i] = scalar_name + scalar_name = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20') + for i, name in enumerate(data_for_points_keys): + # Use the last two bytes of the name to store the number of values + # associated to this data_for_streamline. + nb_values = data_for_points[name].shape[-1] + scalar_name[i] = encode_value_in_name(nb_values, name) + self.header['scalar_name'][:] = scalar_name # Make sure streamlines are in rasmm then send them to voxmm. tractogram = tractogram.to_world(lazy=True) From 9e2a39c7d11855da41c58e12ee73e90b90608928 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 23 May 2016 16:29:24 -0400 Subject: [PATCH 105/135] Be numpy 1.11 compliant --- nibabel/streamlines/array_sequence.py | 2 +- nibabel/streamlines/tractogram.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index 35ebdfb45a..d60ab4908a 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -75,7 +75,7 @@ def __init__(self, iterable=None, buffer_size=4): except TypeError: # Can't get the number of elements in iterable. So, # we use a memory buffer while building the ArraySequence. 
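The int() cast added in the next hunk matters because a float buffer_size turns the floor division into a float, and newer numpy releases reject non-integer array shapes. A standalone sketch of the failure mode, with made-up sizes:

    import numpy as np

    nbytes = 12                       # e.g. one float32 point of 3 coords
    n_rows = 4.0 * 1024**2 // nbytes  # floor division of a float is a float
    # np.empty((n_rows, 3))           # newer numpy: TypeError, non-int shape
    buf = np.empty((int(n_rows), 3))  # the cast restores an integer shape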
- n_rows_buffer = buffer_size*1024**2 // e.nbytes + n_rows_buffer = int(buffer_size*1024**2 // e.nbytes) new_shape = (n_rows_buffer,) + e.shape[1:] self._data = np.empty(new_shape, dtype=e.dtype) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index ff765d874b..a977238b76 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -42,7 +42,7 @@ def __init__(self, *args, **kwargs): def __getitem__(self, key): try: return self.store[key] - except (KeyError, TypeError): + except (KeyError, TypeError, IndexError): pass # Maybe it is an integer or a slicing object # Try to interpret key as an index/slice for every data element, in @@ -53,7 +53,7 @@ def __getitem__(self, key): try: for k, v in self.items(): new_dict[k] = v[idx] - except (TypeError, ValueError): + except (TypeError, ValueError, IndexError): pass else: return new_dict From 0b4a3756c014d62e35302d623e61319fe1d79fe6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 23 May 2016 16:48:21 -0400 Subject: [PATCH 106/135] PEP8 --- nibabel/streamlines/array_sequence.py | 2 +- nibabel/streamlines/trk.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index d60ab4908a..adfd98439e 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -75,7 +75,7 @@ def __init__(self, iterable=None, buffer_size=4): except TypeError: # Can't get the number of elements in iterable. So, # we use a memory buffer while building the ArraySequence. - n_rows_buffer = int(buffer_size*1024**2 // e.nbytes) + n_rows_buffer = int(buffer_size * 1024**2 // e.nbytes) new_shape = (n_rows_buffer,) + e.shape[1:] self._data = np.empty(new_shape, dtype=e.dtype) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 9f34ef8d9e..f454b9f8d6 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -154,11 +154,11 @@ def encode_value_in_name(value, name, max_name_len=20): msg = ("Data information named '{0}' is too long" " (max {1} characters.)").format(name, max_name_len) raise ValueError(msg) - elif len(name) > max_name_len-2 and value > 1: + elif len(name) > max_name_len - 2 and value > 1: msg = ("Data information named '{0}' is too long (need to be less" " than {1} characters when storing more than one value" " for a given data information." 
- ).format(name, max_name_len-2) + ).format(name, max_name_len - 2) raise ValueError(msg) if value > 1: From 8348d1e9a12734558da65a1ac85157d74afe2b8d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 24 May 2016 09:05:57 -0400 Subject: [PATCH 107/135] Addressed @matthew-brett's comments --- nibabel/streamlines/tests/test_tractogram.py | 8 ++--- nibabel/streamlines/trk.py | 35 +++++++++++--------- nibabel/streamlines/utils.py | 10 +++--- nibabel/testing/__init__.py | 13 +------- 4 files changed, 30 insertions(+), 36 deletions(-) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 40781f70a2..68d8543d13 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -3,7 +3,7 @@ import numpy as np import warnings -from nibabel.testing import assert_arrays_equal, check_iteration +from nibabel.testing import assert_arrays_equal from nibabel.testing import clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true from numpy.testing import assert_array_equal, assert_array_almost_equal @@ -110,7 +110,7 @@ def check_tractogram(tractogram, streamlines = list(streamlines) assert_equal(len(tractogram), len(streamlines)) assert_arrays_equal(tractogram.streamlines, streamlines) - assert_true(check_iteration(tractogram)) + [t for t in tractogram] # Force iteration through tractogram. assert_equal(len(tractogram.data_per_streamline), len(data_per_streamline)) for key in data_per_streamline.keys(): @@ -556,7 +556,7 @@ def test_lazy_tractogram_creation(self): DATA['data_per_streamline_func'], DATA['data_per_point_func']) - assert_true(check_iteration(tractogram)) + [t for t in tractogram] # Force iteration through tractogram. assert_equal(len(tractogram), len(DATA['streamlines'])) # Generator functions get re-called and creates new iterators. @@ -627,7 +627,7 @@ def test_lazy_tractogram_len(self): tractogram = LazyTractogram(DATA['streamlines_func']) assert_true(tractogram._nb_streamlines is None) - check_iteration(tractogram) # Force iteration through tractogram. + [t for t in tractogram] # Force iteration through tractogram. assert_equal(tractogram._nb_streamlines, len(DATA['streamlines'])) # This should *not* produce a warning. assert_equal(len(tractogram), len(DATA['streamlines'])) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index f454b9f8d6..c43bad02de 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -100,8 +100,8 @@ def get_affine_trackvis_to_rasmm(header): affine = np.dot(scale, affine) # TrackVis considers coordinate (0,0,0) to be the corner of the - # voxel whereas streamlines returned assume (0,0,0) to be the - # center of the voxel. Thus, streamlines are shifted of half a voxel. + # voxel whereas streamlines returned assumes (0,0,0) to be the + # center of the voxel. Thus, streamlines are shifted by half a voxel. offset = np.eye(4) offset[:-1, -1] -= 0.5 affine = np.dot(offset, affine) @@ -132,21 +132,24 @@ def encode_value_in_name(value, name, max_name_len=20): """ Encodes a value in the last two bytes of a string. If `value` is one, then there is no encoding and the last two bytes - are left untouched. This function also verify that the length of name is - less than `max_name_len`. + are left untouched. Otherwise, the byte before the last will be + set to \x00 and the last byte will correspond to the value. 
+ + This function also verifies that the length of name is less + than `max_name_len`. Parameters ---------- - value : int - Integer value to encode. - name : str + value : byte + Integer value between 0 and 255 to encode. + name : bytes Name in which the last two bytes will serve to encode `value`. max_name_len : int, optional Maximum length name can have. Returns ------- - encoded_name : str + encoded_name : bytes Name containing the encoded value. """ @@ -161,9 +164,10 @@ def encode_value_in_name(value, name, max_name_len=20): ).format(name, max_name_len - 2) raise ValueError(msg) + name = name.ljust(max_name_len, '\x00') if value > 1: # Use the last two bytes of `name` to store `value`. - name = (asbytes(name[:18].ljust(18, '\x00')) + b'\x00' + + name = (asbytes(name[:max_name_len - 2]) + b'\x00' + np.array(value, dtype=np.int8).tostring()) return name @@ -183,10 +187,10 @@ class TrkReader(object): ---- TrackVis (so its file format: TRK) considers the streamline coordinate (0,0,0) to be in the corner of the voxel whereas NiBabel's streamlines - internal representation (Voxel space) assume (0,0,0) to be in the + internal representation (Voxel space) assumes (0,0,0) to be in the center of the voxel. - Thus, streamlines are shifted of half a voxel on load and are shifted + Thus, streamlines are shifted by half a voxel on load and are shifted back on save. """ def __init__(self, fileobj): @@ -461,10 +465,10 @@ class TrkFile(TractogramFile): ---- TrackVis (so its file format: TRK) considers the streamline coordinate (0,0,0) to be in the corner of the voxel whereas NiBabel's streamlines - internal representation (Voxel space) assume (0,0,0) to be in the + internal representation (Voxel space) assumes (0,0,0) to be in the center of the voxel. - Thus, streamlines are shifted of half a voxel on load and are shifted + Thus, streamlines are shifted by half a voxel on load and are shifted back on save. """ @@ -515,8 +519,9 @@ def is_correct_format(cls, fileobj): otherwise returns False. """ with Opener(fileobj) as f: - magic_number = f.read(5) - f.seek(-5, os.SEEK_CUR) + magic_len = len(cls.MAGIC_NUMBER) + magic_number = f.read(magic_len) + f.seek(-magic_len, os.SEEK_CUR) return magic_number == cls.MAGIC_NUMBER @classmethod diff --git a/nibabel/streamlines/utils.py b/nibabel/streamlines/utils.py index 3a5c648cc9..0ef5b740ac 100644 --- a/nibabel/streamlines/utils.py +++ b/nibabel/streamlines/utils.py @@ -1,4 +1,3 @@ -import numpy as np import nibabel @@ -9,7 +8,7 @@ def get_affine_from_reference(ref): --------- ref : str or :class:`Nifti1Image` object or ndarray shape (4, 4) If str then it's the filename of reference file that will be loaded - using :func:nibabel.load in order to obtain the affine. + using :func:`nibabel.load` in order to obtain the affine. If :class:`Nifti1Image` object then the affine is obtained from it. If ndarray shape (4, 4) then it's the affine. @@ -18,14 +17,15 @@ def get_affine_from_reference(ref): affine : ndarray (4, 4) Transformation matrix mapping voxel space to RAS+mm space. """ - if type(ref) is np.ndarray: + if hasattr(ref, 'affine'): + return ref.affine + + if hasattr(ref, 'shape'): if ref.shape != (4, 4): msg = "`ref` needs to be a numpy array with shape (4, 4)!" raise ValueError(msg) return ref - elif hasattr(ref, 'affine'): - return ref.affine # Assume `ref` is the name of a neuroimaging file. 
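Checking hasattr(ref, 'affine') first makes this helper duck-typed: any image-like object exposing an .affine attribute is now accepted, not only SpatialImage instances. A sketch of the accepted inputs, with a throwaway image:

    import numpy as np
    import nibabel as nib
    from nibabel.streamlines.utils import get_affine_from_reference

    img = nib.Nifti1Image(np.zeros((2, 2, 2)), np.diag([2., 2., 2., 1.]))
    get_affine_from_reference(img)        # anything with an .affine attribute
    get_affine_from_reference(np.eye(4))  # a (4, 4) array is passed through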
return nibabel.load(ref).affine diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py index 6868b18269..8e0cd982e5 100644 --- a/nibabel/testing/__init__.py +++ b/nibabel/testing/__init__.py @@ -65,19 +65,8 @@ def assert_allclose_safely(a, b, match_nans=True, rtol=1e-5, atol=1e-8): assert_true(np.allclose(a, b, rtol=rtol, atol=atol)) -def check_iteration(iterable): - """ Checks that an object can be iterated through without errors. """ - try: - for _ in iterable: - pass - except: - return False - - return True - - def assert_arrays_equal(arrays1, arrays2): - """ Checks that two iterables yielding arrays are equals. """ + """ Check two iterables yield the same sequence of arrays. """ for arr1, arr2 in zip_longest(arrays1, arrays2, fillvalue=None): assert_false(arr1 is None or arr2 is None) assert_array_equal(arr1, arr2) From c950b711e200b52bf0dee84d6ed704dccdf8d5aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 24 May 2016 09:17:41 -0400 Subject: [PATCH 108/135] Merge TrkWriter and TrkReader with TrkFile --- nibabel/streamlines/trk.py | 593 +++++++++++++++++++------------------ 1 file changed, 305 insertions(+), 288 deletions(-) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index c43bad02de..af733aa46b 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -173,289 +173,20 @@ def encode_value_in_name(value, name, max_name_len=20): return name -class TrkReader(object): - """ Convenience class to encapsulate TRK file format. - - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to TRK file (and ready to read from the beginning - of the TRK header) - - Note - ---- - TrackVis (so its file format: TRK) considers the streamline coordinate - (0,0,0) to be in the corner of the voxel whereas NiBabel's streamlines - internal representation (Voxel space) assumes (0,0,0) to be in the - center of the voxel. - - Thus, streamlines are shifted by half a voxel on load and are shifted - back on save. - """ - def __init__(self, fileobj): - self.fileobj = fileobj - - with Opener(self.fileobj) as f: - # Read the header in one block. - header_str = f.read(header_2_dtype.itemsize) - header_rec = np.fromstring(string=header_str, dtype=header_2_dtype) - - # Check endianness - endianness = native_code - if header_rec['hdr_size'] != TrkFile.HEADER_SIZE: - endianness = swapped_code - - # Swap byte order - header_rec = header_rec.newbyteorder() - if header_rec['hdr_size'] != TrkFile.HEADER_SIZE: - msg = "Invalid hdr_size: {0} instead of {1}" - raise HeaderError(msg.format(header_rec['hdr_size'], - TrkFile.HEADER_SIZE)) - - if header_rec['version'] == 1: - header_rec = np.fromstring(string=header_str, - dtype=header_1_dtype) - elif header_rec['version'] == 2: - pass # Nothing more to do. - else: - raise HeaderError('NiBabel only supports versions 1 and 2.') - - # Convert the first record of `header_rec` into a dictionnary - self.header = dict(zip(header_rec.dtype.names, header_rec[0])) - self.header[Field.ENDIANNESS] = endianness - - # If vox_to_ras[3][3] is 0, it means the matrix is not recorded. - if self.header[Field.VOXEL_TO_RASMM][3][3] == 0: - self.header[Field.VOXEL_TO_RASMM] = np.eye(4, dtype=np.float32) - warnings.warn(("Field 'vox_to_ras' in the TRK's header was" - " not recorded. Will continue assuming it's" - " the identity."), HeaderWarning) - - # Check that the 'vox_to_ras' affine is valid, i.e. 
should be - # able to determine the axis directions. - axcodes = aff2axcodes(self.header[Field.VOXEL_TO_RASMM]) - if None in axcodes: - msg = ("The 'vox_to_ras' affine is invalid! Could not" - " determine the axis directions from it.\n{0}" - ).format(self.header[Field.VOXEL_TO_RASMM]) - raise HeaderError(msg) - - # By default, the voxel order is LPS. - # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates - if self.header[Field.VOXEL_ORDER] == b"": - msg = ("Voxel order is not specified, will assume 'LPS' since" - "it is Trackvis software's default.") - warnings.warn(msg, HeaderWarning) - self.header[Field.VOXEL_ORDER] = b"LPS" - - # Keep the file position where the data begin. - self.offset_data = f.tell() - - def __iter__(self): - i4_dtype = np.dtype(self.header[Field.ENDIANNESS] + "i4") - f4_dtype = np.dtype(self.header[Field.ENDIANNESS] + "f4") - - with Opener(self.fileobj) as f: - start_position = f.tell() - - nb_pts_and_scalars = int(3 + - self.header[Field.NB_SCALARS_PER_POINT]) - pts_and_scalars_size = int(nb_pts_and_scalars * f4_dtype.itemsize) - nb_properties = self.header[Field.NB_PROPERTIES_PER_STREAMLINE] - properties_size = int(nb_properties * f4_dtype.itemsize) +def create_empty_header(): + """ Return an empty compliant TRK header. """ + header = np.zeros(1, dtype=header_2_dtype) - # Set the file position at the beginning of the data. - f.seek(self.offset_data, os.SEEK_SET) + # Default values + header[Field.MAGIC_NUMBER] = TrkFile.MAGIC_NUMBER + header[Field.VOXEL_SIZES] = np.array((1, 1, 1), dtype="f4") + header[Field.DIMENSIONS] = np.array((1, 1, 1), dtype="h") + header[Field.VOXEL_TO_RASMM] = np.eye(4, dtype="f4") + header[Field.VOXEL_ORDER] = b"RAS" + header['version'] = 2 + header['hdr_size'] = TrkFile.HEADER_SIZE - # If 'count' field is 0, i.e. not provided, we have to loop - # until the EOF. - nb_streamlines = self.header[Field.NB_STREAMLINES] - if nb_streamlines == 0: - nb_streamlines = np.inf - - count = 0 - nb_pts_dtype = i4_dtype.str[:-1] - while count < nb_streamlines: - nb_pts_str = f.read(i4_dtype.itemsize) - - # Check if we reached EOF - if len(nb_pts_str) == 0: - break - - # Read number of points of the next streamline. - nb_pts = struct.unpack(nb_pts_dtype, nb_pts_str)[0] - - # Read streamline's data - points_and_scalars = np.ndarray( - shape=(nb_pts, nb_pts_and_scalars), - dtype=f4_dtype, - buffer=f.read(nb_pts * pts_and_scalars_size)) - - points = points_and_scalars[:, :3] - scalars = points_and_scalars[:, 3:] - - # Read properties - properties = np.ndarray( - shape=(nb_properties,), - dtype=f4_dtype, - buffer=f.read(properties_size)) - - yield points, scalars, properties - count += 1 - - # In case the 'count' field was not provided. - self.header[Field.NB_STREAMLINES] = count - - # Set the file position where it was (in case it was already open). - f.seek(start_position, os.SEEK_CUR) - - -class TrkWriter(object): - @classmethod - def create_empty_header(cls): - """ Return an empty compliant TRK header. 
""" - header = np.zeros(1, dtype=header_2_dtype) - - # Default values - header[Field.MAGIC_NUMBER] = TrkFile.MAGIC_NUMBER - header[Field.VOXEL_SIZES] = np.array((1, 1, 1), dtype="f4") - header[Field.DIMENSIONS] = np.array((1, 1, 1), dtype="h") - header[Field.VOXEL_TO_RASMM] = np.eye(4, dtype="f4") - header[Field.VOXEL_ORDER] = b"RAS" - header['version'] = 2 - header['hdr_size'] = TrkFile.HEADER_SIZE - - return header - - def __init__(self, fileobj, header): - self.header = self.create_empty_header() - - # Override hdr's fields by those contained in `header`. - for k, v in header.items(): - if k in header_2_dtype.fields.keys(): - self.header[k] = v - - # By default, the voxel order is LPS. - # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates - if self.header[Field.VOXEL_ORDER] == b"": - self.header[Field.VOXEL_ORDER] = b"LPS" - - # Keep counts for correcting incoherent fields or warn. - self.nb_streamlines = 0 - self.nb_points = 0 - self.nb_scalars = 0 - self.nb_properties = 0 - - # Write header - self.header = self.header[0] - self.file = Opener(fileobj, mode="wb") - # Keep track of the beginning of the header. - self.beginning = self.file.tell() - - self.file.write(self.header.tostring()) - - def write(self, tractogram): - i4_dtype = np.dtype(" MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE: - msg = ("Can only store {0} named data_per_streamline (also known" - " as 'properties' in the TRK format)." - ).format(MAX_NB_NAMED_SCALARS_PER_POINT) - raise ValueError(msg) - - data_for_streamline_keys = sorted(data_for_streamline.keys()) - property_name = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, - dtype='S20') - for i, name in enumerate(data_for_streamline_keys): - # Use the last to bytes of the name to store the number of values - # associated to this data_for_streamline. - nb_values = data_for_streamline[name].shape[-1] - property_name[i] = encode_value_in_name(nb_values, name) - self.header['property_name'][:] = property_name - - # Update field 'scalar_name' using 'tractogram.data_per_point'. - data_for_points = first_item.data_for_points - if len(data_for_points) > MAX_NB_NAMED_SCALARS_PER_POINT: - raise ValueError(("Can only store {0} named data_per_point (also" - " known as 'scalars' in the TRK format)." - ).format(MAX_NB_NAMED_SCALARS_PER_POINT)) - - data_for_points_keys = sorted(data_for_points.keys()) - scalar_name = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20') - for i, name in enumerate(data_for_points_keys): - # Use the last two bytes of the name to store the number of values - # associated to this data_for_streamline. - nb_values = data_for_points[name].shape[-1] - scalar_name[i] = encode_value_in_name(nb_values, name) - self.header['scalar_name'][:] = scalar_name - - # Make sure streamlines are in rasmm then send them to voxmm. 
- tractogram = tractogram.to_world(lazy=True) - affine_to_trackvis = get_affine_rasmm_to_trackvis(self.header) - tractogram = tractogram.apply_affine(affine_to_trackvis, lazy=True) - - for t in tractogram: - if any((len(d) != len(t.streamline) - for d in t.data_for_points.values())): - raise DataError("Missing scalars for some points!") - - points = np.asarray(t.streamline, dtype=f4_dtype) - scalars = [np.asarray(t.data_for_points[k], dtype=f4_dtype) - for k in data_for_points_keys] - scalars = np.concatenate([np.ndarray((len(points), 0), - dtype=f4_dtype) - ] + scalars, axis=1) - properties = [np.asarray(t.data_for_streamline[k], dtype=f4_dtype) - for k in data_for_streamline_keys] - properties = np.concatenate([np.array([], dtype=f4_dtype) - ] + properties) - - data = struct.pack(i4_dtype.str[:-1], len(points)) - data += np.concatenate([points, scalars], axis=1).tostring() - data += properties.tostring() - self.file.write(data) - - self.nb_streamlines += 1 - self.nb_points += len(points) - self.nb_scalars += scalars.size - self.nb_properties += len(properties) - - # Use those values to update the header. - nb_scalars_per_point = self.nb_scalars / self.nb_points - nb_properties_per_streamline = self.nb_properties / self.nb_streamlines - - # Check for errors - if nb_scalars_per_point != int(nb_scalars_per_point): - msg = "Nb. of scalars differs from one point to another!" - raise DataError(msg) - - if nb_properties_per_streamline != int(nb_properties_per_streamline): - msg = "Nb. of properties differs from one streamline to another!" - raise DataError(msg) - - self.header[Field.NB_STREAMLINES] = self.nb_streamlines - self.header[Field.NB_SCALARS_PER_POINT] = nb_scalars_per_point - self.header[Field.NB_PROPERTIES_PER_STREAMLINE] = nb_properties_per_streamline - - # Overwrite header with updated one. - self.file.seek(self.beginning, os.SEEK_SET) - self.file.write(self.header.tostring()) + return header class TrkFile(TractogramFile): @@ -495,7 +226,7 @@ def __init__(self, tractogram, header=None): of the voxel. """ if header is None: - header_rec = TrkWriter.create_empty_header() + header_rec = create_empty_header() header = dict(zip(header_rec.dtype.names, header_rec[0])) super(TrkFile, self).__init__(tractogram, header) @@ -533,7 +264,8 @@ def load(cls, fileobj, lazy_load=False): fileobj : string or file-like object If string, a filename; otherwise an open file-like object pointing to TRK file (and ready to read from the beginning - of the TRK header). + of the TRK header). Note that calling this function + does not change the file position. lazy_load : {False, True}, optional If True, load streamlines in a lazy manner i.e. they will not be kept in memory. Otherwise, load all streamlines in memory. @@ -550,8 +282,7 @@ def load(cls, fileobj, lazy_load=False): and *mm* space where coordinate (0,0,0) refers to the center of the voxel. 
""" - trk_reader = TrkReader(fileobj) - hdr = trk_reader.header + hdr = cls._read_header(fileobj) # Find scalars and properties name data_per_point_slice = {} @@ -603,7 +334,7 @@ def load(cls, fileobj, lazy_load=False): if lazy_load: def _read(): - for pts, scals, props in trk_reader: + for pts, scals, props in cls._read(fileobj, hdr): items = data_per_point_slice.items() data_for_points = dict((k, scals[:, v]) for k, v in items) items = data_per_streamline_slice.items() @@ -615,6 +346,7 @@ def _read(): tractogram = LazyTractogram.create_from(_read) else: + trk_reader = cls._read(fileobj, hdr) arr_seqs = create_arraysequences_from_generator(trk_reader, n=3) streamlines, scalars, properties = arr_seqs properties = np.asarray(properties) # Actually a 2d array. @@ -641,8 +373,293 @@ def save(self, fileobj): pointing to TRK file (and ready to read from the beginning of the TRK header data). """ - trk_writer = TrkWriter(fileobj, self.header) - trk_writer.write(self.tractogram) + header = create_empty_header() + + # Override hdr's fields by those contained in `header`. + for k, v in self.header.items(): + if k in header_2_dtype.fields.keys(): + header[k] = v + + # By default, the voxel order is LPS. + # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates + if header[Field.VOXEL_ORDER] == b"": + header[Field.VOXEL_ORDER] = b"LPS" + + # Keep counts for correcting incoherent fields or warn. + nb_streamlines = 0 + nb_points = 0 + nb_scalars = 0 + nb_properties = 0 + + header = header[0] + with Opener(fileobj, mode="wb") as f: + # Keep track of the beginning of the header. + beginning = f.tell() + + f.write(header.tostring()) + + i4_dtype = np.dtype(" MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE: + msg = ("Can only store {0} named data_per_streamline (also" + " known as 'properties' in the TRK format)." + ).format(MAX_NB_NAMED_SCALARS_PER_POINT) + raise ValueError(msg) + + data_for_streamline_keys = sorted(data_for_streamline.keys()) + property_name = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, + dtype='S20') + for i, name in enumerate(data_for_streamline_keys): + # Use the last to bytes of the name to store the number of + # values associated to this data_for_streamline. + nb_values = data_for_streamline[name].shape[-1] + property_name[i] = encode_value_in_name(nb_values, name) + header['property_name'][:] = property_name + + # Update field 'scalar_name' using 'tractogram.data_per_point'. + data_for_points = first_item.data_for_points + if len(data_for_points) > MAX_NB_NAMED_SCALARS_PER_POINT: + msg = ("Can only store {0} named data_per_point (also known" + " as 'scalars' in the TRK format)." + ).format(MAX_NB_NAMED_SCALARS_PER_POINT) + raise ValueError(msg) + + data_for_points_keys = sorted(data_for_points.keys()) + scalar_name = np.zeros(MAX_NB_NAMED_SCALARS_PER_POINT, dtype='S20') + for i, name in enumerate(data_for_points_keys): + # Use the last two bytes of the name to store the number of + # values associated to this data_for_streamline. + nb_values = data_for_points[name].shape[-1] + scalar_name[i] = encode_value_in_name(nb_values, name) + header['scalar_name'][:] = scalar_name + + # Make sure streamlines are in rasmm then send them to voxmm. 
+            tractogram = self.tractogram.to_world(lazy=True)
+            affine_to_trackvis = get_affine_rasmm_to_trackvis(header)
+            tractogram = tractogram.apply_affine(affine_to_trackvis, lazy=True)
+
+            for t in tractogram:
+                if any((len(d) != len(t.streamline)
+                        for d in t.data_for_points.values())):
+                    raise DataError("Missing scalars for some points!")
+
+                points = np.asarray(t.streamline, dtype=f4_dtype)
+                scalars = [np.asarray(t.data_for_points[k], dtype=f4_dtype)
+                           for k in data_for_points_keys]
+                scalars = np.concatenate([np.ndarray((len(points), 0),
+                                                     dtype=f4_dtype)
+                                          ] + scalars, axis=1)
+                properties = [np.asarray(t.data_for_streamline[k],
+                                         dtype=f4_dtype)
+                              for k in data_for_streamline_keys]
+                properties = np.concatenate([np.array([], dtype=f4_dtype)
+                                             ] + properties)
+
+                data = struct.pack(i4_dtype.str[:-1], len(points))
+                data += np.concatenate([points, scalars], axis=1).tostring()
+                data += properties.tostring()
+                f.write(data)
+
+                nb_streamlines += 1
+                nb_points += len(points)
+                nb_scalars += scalars.size
+                nb_properties += len(properties)
+
+            # Use those values to update the header.
+            nb_scalars_per_point = nb_scalars / nb_points
+            nb_properties_per_streamline = nb_properties / nb_streamlines
+
+            # Check for errors
+            if nb_scalars_per_point != int(nb_scalars_per_point):
+                msg = "Nb. of scalars differs from one point to another!"
+                raise DataError(msg)
+
+            if nb_properties_per_streamline != int(nb_properties_per_streamline):
+                msg = ("Nb. of properties differs from one streamline to"
+                       " another!")
+                raise DataError(msg)
+
+            header[Field.NB_STREAMLINES] = nb_streamlines
+            header[Field.NB_SCALARS_PER_POINT] = nb_scalars_per_point
+            header[Field.NB_PROPERTIES_PER_STREAMLINE] = nb_properties_per_streamline
+
+            # Overwrite header with updated one.
+            f.seek(beginning, os.SEEK_SET)
+            f.write(header.tostring())
+
+    @staticmethod
+    def _read_header(fileobj):
+        """ Reads a TRK header from a file.
+
+        Parameters
+        ----------
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            pointing to TRK file (and ready to read from the beginning
+            of the TRK header). Note that calling this function
+            does not change the file position.
+
+        Returns
+        -------
+        header : dict
+            Metadata associated to this tractogram file.
+        """
+        with Opener(fileobj) as f:
+            start_position = f.tell()
+
+            # Read the header in one block.
+            header_str = f.read(header_2_dtype.itemsize)
+            header_rec = np.fromstring(string=header_str, dtype=header_2_dtype)
+
+            # Check endianness
+            endianness = native_code
+            if header_rec['hdr_size'] != TrkFile.HEADER_SIZE:
+                endianness = swapped_code
+
+                # Swap byte order
+                header_rec = header_rec.newbyteorder()
+                if header_rec['hdr_size'] != TrkFile.HEADER_SIZE:
+                    msg = "Invalid hdr_size: {0} instead of {1}"
+                    raise HeaderError(msg.format(header_rec['hdr_size'],
+                                                 TrkFile.HEADER_SIZE))
+
+            if header_rec['version'] == 1:
+                header_rec = np.fromstring(string=header_str,
+                                           dtype=header_1_dtype)
+            elif header_rec['version'] == 2:
+                pass  # Nothing more to do.
+            else:
+                raise HeaderError('NiBabel only supports versions 1 and 2.')
+
+            # Convert the first record of `header_rec` into a dictionary
+            header = dict(zip(header_rec.dtype.names, header_rec[0]))
+            header[Field.ENDIANNESS] = endianness
+
+            # If vox_to_ras[3][3] is 0, it means the matrix is not recorded.
+            if header[Field.VOXEL_TO_RASMM][3][3] == 0:
+                header[Field.VOXEL_TO_RASMM] = np.eye(4, dtype=np.float32)
+                warnings.warn(("Field 'vox_to_ras' in the TRK's header was"
+                               " not recorded. Will continue assuming it's"
+                               " the identity."), HeaderWarning)
+
+            # Check that the 'vox_to_ras' affine is valid, i.e. should be
+            # able to determine the axis directions.
+            axcodes = aff2axcodes(header[Field.VOXEL_TO_RASMM])
+            if None in axcodes:
+                msg = ("The 'vox_to_ras' affine is invalid! Could not"
+                       " determine the axis directions from it.\n{0}"
+                       ).format(header[Field.VOXEL_TO_RASMM])
+                raise HeaderError(msg)
+
+            # By default, the voxel order is LPS.
+            # http://trackvis.org/blog/forum/diffusion-toolkit-usage/interpretation-of-track-point-coordinates
+            if header[Field.VOXEL_ORDER] == b"":
+                msg = ("Voxel order is not specified, will assume 'LPS'"
+                       " since it is Trackvis software's default.")
+                warnings.warn(msg, HeaderWarning)
+                header[Field.VOXEL_ORDER] = b"LPS"
+
+            # Keep the file position where the data begin.
+            header['_offset_data'] = f.tell()
+
+            # Set the file position where it was (in case it was already open).
+            f.seek(start_position, os.SEEK_CUR)
+
+        return header
+
+    @staticmethod
+    def _read(fileobj, header):
+        """ Reads TRK data from a file.
+
+        Parameters
+        ----------
+        fileobj : string or file-like object
+            If string, a filename; otherwise an open file-like object
+            pointing to TRK file (and ready to read from the beginning
+            of the TRK header). Note that calling this function
+            does not change the file position.
+        header : dict
+            Metadata associated to this tractogram file.
+
+        Yields
+        ------
+        data : tuple of ndarrays
+            Streamline data: points, scalars, properties.
+            points: ndarray of shape (n_pts, 3)
+            scalars: ndarray of shape (n_pts, nb_scalars_per_per_point)
+            properties: ndarray of shape (nb_properties_per_per_point,)
+        """
+        i4_dtype = np.dtype(header[Field.ENDIANNESS] + "i4")
+        f4_dtype = np.dtype(header[Field.ENDIANNESS] + "f4")
+
+        with Opener(fileobj) as f:
+            start_position = f.tell()
+
+            nb_pts_and_scalars = int(3 +
+                                     header[Field.NB_SCALARS_PER_POINT])
+            pts_and_scalars_size = int(nb_pts_and_scalars * f4_dtype.itemsize)
+            nb_properties = header[Field.NB_PROPERTIES_PER_STREAMLINE]
+            properties_size = int(nb_properties * f4_dtype.itemsize)
+
+            # Set the file position at the beginning of the data.
+            f.seek(header["_offset_data"], os.SEEK_SET)
+
+            # If 'count' field is 0, i.e. not provided, we have to loop
+            # until the EOF.
+            nb_streamlines = header[Field.NB_STREAMLINES]
+            if nb_streamlines == 0:
+                nb_streamlines = np.inf
+
+            count = 0
+            nb_pts_dtype = i4_dtype.str[:-1]
+            while count < nb_streamlines:
+                nb_pts_str = f.read(i4_dtype.itemsize)
+
+                # Check if we reached EOF
+                if len(nb_pts_str) == 0:
+                    break
+
+                # Read number of points of the next streamline.
+                nb_pts = struct.unpack(nb_pts_dtype, nb_pts_str)[0]
+
+                # Read streamline's data
+                points_and_scalars = np.ndarray(
+                    shape=(nb_pts, nb_pts_and_scalars),
+                    dtype=f4_dtype,
+                    buffer=f.read(nb_pts * pts_and_scalars_size))
+
+                points = points_and_scalars[:, :3]
+                scalars = points_and_scalars[:, 3:]
+
+                # Read properties
+                properties = np.ndarray(
+                    shape=(nb_properties,),
+                    dtype=f4_dtype,
+                    buffer=f.read(properties_size))
+
+                yield points, scalars, properties
+                count += 1
+
+            # In case the 'count' field was not provided.
+            header[Field.NB_STREAMLINES] = count
+
+            # Set the file position where it was (in case it was already open).
+            f.seek(start_position, os.SEEK_CUR)
 
     def __str__(self):
         """ Gets a formatted string of the header of a TRK file.
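Splitting the reader into ``_read_header`` and the generator ``_read`` is what
keeps lazy loading cheap: the fixed-size header is parsed once, after which
streamlines are decoded one record at a time, so memory use stays roughly
constant regardless of file size. A minimal sketch of how this surfaces
through the public ``nibabel.streamlines`` API (the file name here is
hypothetical; any TRK file works)::

    import numpy as np
    import nibabel as nib

    # Streamlines are decoded on the fly while iterating, so only one
    # streamline lives in memory at a time.
    trk = nib.streamlines.load("bundle.trk", lazy_load=True)
    print(trk.header["nb_streamlines"])

    lengths = [len(pts) for pts in trk.tractogram.streamlines]
    print(np.mean(lengths))

Because a lazy tractogram is backed by generator functions, it suits
single-pass, linearizable operations (format conversion, applying an affine)
better than random access.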
From ae62594f54579b3119829c465a977d962e25ed64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Thu, 26 May 2016 11:27:23 -0400 Subject: [PATCH 109/135] Refactored test_trk.py --- nibabel/streamlines/tests/test_tractogram.py | 5 +- nibabel/streamlines/tests/test_trk.py | 236 ++++++++++--------- 2 files changed, 123 insertions(+), 118 deletions(-) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 68d8543d13..5870693674 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -53,8 +53,9 @@ def setup(): 'mean_torsion': DATA['mean_torsion'], 'mean_colors': DATA['mean_colors']} - DATA['empty_tractogram'] = Tractogram() - DATA['simple_tractogram'] = Tractogram(DATA['streamlines']) + DATA['empty_tractogram'] = Tractogram(affine_to_rasmm=np.eye(4)) + DATA['simple_tractogram'] = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) DATA['tractogram'] = Tractogram(DATA['streamlines'], DATA['data_per_streamline'], DATA['data_per_point'], diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 7425947aad..87bceee36b 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -6,20 +6,80 @@ from nibabel.externals.six import BytesIO -from nibabel.testing import suppress_warnings, clear_and_catch_warnings -from nibabel.testing import assert_arrays_equal +from nibabel.testing import clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true from numpy.testing import assert_array_equal from .test_tractogram import assert_tractogram_equal -from ..tractogram import Tractogram, LazyTractogram -from ..tractogram_file import DataError, HeaderError, HeaderWarning +from ..tractogram import Tractogram +from ..tractogram_file import HeaderError, HeaderWarning from .. 
import trk as trk_module from ..trk import TrkFile from ..header import Field DATA_PATH = pjoin(os.path.dirname(__file__), 'data') +DATA = {} + + +def setup(): + global DATA + + DATA['empty_trk_fname'] = pjoin(DATA_PATH, "empty.trk") + # simple.trk contains only streamlines + DATA['simple_trk_fname'] = pjoin(DATA_PATH, "simple.trk") + # standard.trk contains only streamlines + DATA['standard_trk_fname'] = pjoin(DATA_PATH, "standard.trk") + # standard.LPS.trk contains only streamlines + DATA['standard_LPS_trk_fname'] = pjoin(DATA_PATH, "standard.LPS.trk") + + # complex.trk contains streamlines, scalars and properties + DATA['complex_trk_fname'] = pjoin(DATA_PATH, "complex.trk") + DATA['complex_trk_big_endian_fname'] = pjoin(DATA_PATH, + "complex_big_endian.trk") + + DATA['streamlines'] = [np.arange(1*3, dtype="f4").reshape((1, 3)), + np.arange(2*3, dtype="f4").reshape((2, 3)), + np.arange(5*3, dtype="f4").reshape((5, 3))] + + DATA['fa'] = [np.array([[0.2]], dtype="f4"), + np.array([[0.3], + [0.4]], dtype="f4"), + np.array([[0.5], + [0.6], + [0.6], + [0.7], + [0.8]], dtype="f4")] + + DATA['colors'] = [np.array([(1, 0, 0)]*1, dtype="f4"), + np.array([(0, 1, 0)]*2, dtype="f4"), + np.array([(0, 0, 1)]*5, dtype="f4")] + + DATA['mean_curvature'] = [np.array([1.11], dtype="f4"), + np.array([2.11], dtype="f4"), + np.array([3.11], dtype="f4")] + + DATA['mean_torsion'] = [np.array([1.22], dtype="f4"), + np.array([2.22], dtype="f4"), + np.array([3.22], dtype="f4")] + + DATA['mean_colors'] = [np.array([1, 0, 0], dtype="f4"), + np.array([0, 1, 0], dtype="f4"), + np.array([0, 0, 1], dtype="f4")] + + DATA['data_per_point'] = {'colors': DATA['colors'], + 'fa': DATA['fa']} + DATA['data_per_streamline'] = {'mean_curvature': DATA['mean_curvature'], + 'mean_torsion': DATA['mean_torsion'], + 'mean_colors': DATA['mean_colors']} + + DATA['empty_tractogram'] = Tractogram(affine_to_rasmm=np.eye(4)) + DATA['simple_tractogram'] = Tractogram(DATA['streamlines'], + affine_to_rasmm=np.eye(4)) + DATA['complex_tractogram'] = Tractogram(DATA['streamlines'], + DATA['data_per_streamline'], + DATA['data_per_point'], + affine_to_rasmm=np.eye(4)) def assert_header_equal(h1, h2): @@ -32,87 +92,29 @@ def assert_header_equal(h1, h2): class TestTRK(unittest.TestCase): - def setUp(self): - self.empty_trk_filename = pjoin(DATA_PATH, "empty.trk") - # simple.trk contains only streamlines - self.simple_trk_filename = pjoin(DATA_PATH, "simple.trk") - # standard.trk contains only streamlines - self.standard_trk_filename = pjoin(DATA_PATH, "standard.trk") - # standard.LPS.trk contains only streamlines - self.standard_LPS_trk_filename = pjoin(DATA_PATH, - "standard.LPS.trk") - - # complex.trk contains streamlines, scalars and properties - self.complex_trk_filename = pjoin(DATA_PATH, "complex.trk") - self.complex_trk_big_endian_filename = pjoin(DATA_PATH, - "complex_big_endian.trk") - - self.streamlines = [np.arange(1*3, dtype="f4").reshape((1, 3)), - np.arange(2*3, dtype="f4").reshape((2, 3)), - np.arange(5*3, dtype="f4").reshape((5, 3))] - - self.fa = [np.array([[0.2]], dtype="f4"), - np.array([[0.3], - [0.4]], dtype="f4"), - np.array([[0.5], - [0.6], - [0.6], - [0.7], - [0.8]], dtype="f4")] - - self.colors = [np.array([(1, 0, 0)]*1, dtype="f4"), - np.array([(0, 1, 0)]*2, dtype="f4"), - np.array([(0, 0, 1)]*5, dtype="f4")] - - self.mean_curvature = [np.array([1.11], dtype="f4"), - np.array([2.11], dtype="f4"), - np.array([3.11], dtype="f4")] - - self.mean_torsion = [np.array([1.22], dtype="f4"), - np.array([2.22], dtype="f4"), - 
np.array([3.22], dtype="f4")] - - self.mean_colors = [np.array([1, 0, 0], dtype="f4"), - np.array([0, 1, 0], dtype="f4"), - np.array([0, 0, 1], dtype="f4")] - - self.data_per_point = {'colors': self.colors, - 'fa': self.fa} - self.data_per_streamline = {'mean_curvature': self.mean_curvature, - 'mean_torsion': self.mean_torsion, - 'mean_colors': self.mean_colors} - - self.empty_tractogram = Tractogram(affine_to_rasmm=np.eye(4)) - self.simple_tractogram = Tractogram(self.streamlines, - affine_to_rasmm=np.eye(4)) - self.complex_tractogram = Tractogram(self.streamlines, - self.data_per_streamline, - self.data_per_point, - affine_to_rasmm=np.eye(4)) - def test_load_empty_file(self): for lazy_load in [False, True]: - trk = TrkFile.load(self.empty_trk_filename, lazy_load=lazy_load) - assert_tractogram_equal(trk.tractogram, self.empty_tractogram) + trk = TrkFile.load(DATA['empty_trk_fname'], lazy_load=lazy_load) + assert_tractogram_equal(trk.tractogram, DATA['empty_tractogram']) def test_load_simple_file(self): for lazy_load in [False, True]: - trk = TrkFile.load(self.simple_trk_filename, lazy_load=lazy_load) - assert_tractogram_equal(trk.tractogram, self.simple_tractogram) + trk = TrkFile.load(DATA['simple_trk_fname'], lazy_load=lazy_load) + assert_tractogram_equal(trk.tractogram, DATA['simple_tractogram']) def test_load_complex_file(self): for lazy_load in [False, True]: - trk = TrkFile.load(self.complex_trk_filename, lazy_load=lazy_load) - assert_tractogram_equal(trk.tractogram, self.complex_tractogram) + trk = TrkFile.load(DATA['complex_trk_fname'], lazy_load=lazy_load) + assert_tractogram_equal(trk.tractogram, DATA['complex_tractogram']) def test_load_file_with_wrong_information(self): - trk_file = open(self.simple_trk_filename, 'rb').read() + trk_file = open(DATA['simple_trk_fname'], 'rb').read() # Simulate a TRK file where `count` was not provided. count = np.array(0, dtype="int32").tostring() new_trk_file = trk_file[:1000-12] + count + trk_file[1000-8:] trk = TrkFile.load(BytesIO(new_trk_file), lazy_load=False) - assert_tractogram_equal(trk.tractogram, self.simple_tractogram) + assert_tractogram_equal(trk.tractogram, DATA['simple_tractogram']) # Simulate a TRK where `vox_to_ras` is not recorded (i.e. all zeros). vox_to_ras = np.zeros((4, 4), dtype=np.float32).tostring() @@ -152,18 +154,18 @@ def test_load_file_with_wrong_information(self): assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) def test_load_complex_file_in_big_endian(self): - trk_file = open(self.complex_trk_big_endian_filename, 'rb').read() + trk_file = open(DATA['complex_trk_big_endian_fname'], 'rb').read() # We use hdr_size as an indicator of little vs big endian. 
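        # (`hdr_size` is always written as 1000; read with the wrong byte
        # order it comes back as a different value, which is exactly how
        # `_read_header` decides the bytes must be swapped.)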
hdr_size_big_endian = np.array(1000, dtype=">i4").tostring() assert_equal(trk_file[996:996+4], hdr_size_big_endian) for lazy_load in [False, True]: - trk = TrkFile.load(self.complex_trk_big_endian_filename, + trk = TrkFile.load(DATA['complex_trk_big_endian_fname'], lazy_load=lazy_load) - assert_tractogram_equal(trk.tractogram, self.complex_tractogram) + assert_tractogram_equal(trk.tractogram, DATA['complex_tractogram']) def test_tractogram_file_properties(self): - trk = TrkFile.load(self.simple_trk_filename) + trk = TrkFile.load(DATA['simple_trk_fname']) assert_equal(trk.streamlines, trk.tractogram.streamlines) assert_equal(trk.get_streamlines(), trk.streamlines) assert_equal(trk.get_tractogram(), trk.tractogram) @@ -181,15 +183,15 @@ def test_write_empty_file(self): new_trk = TrkFile.load(trk_file) assert_tractogram_equal(new_trk.tractogram, tractogram) - new_trk_orig = TrkFile.load(self.empty_trk_filename) + new_trk_orig = TrkFile.load(DATA['empty_trk_fname']) assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(trk_file.read(), - open(self.empty_trk_filename, 'rb').read()) + open(DATA['empty_trk_fname'], 'rb').read()) def test_write_simple_file(self): - tractogram = Tractogram(self.streamlines, + tractogram = Tractogram(DATA['streamlines'], affine_to_rasmm=np.eye(4)) trk_file = BytesIO() @@ -200,17 +202,17 @@ def test_write_simple_file(self): new_trk = TrkFile.load(trk_file) assert_tractogram_equal(new_trk.tractogram, tractogram) - new_trk_orig = TrkFile.load(self.simple_trk_filename) + new_trk_orig = TrkFile.load(DATA['simple_trk_fname']) assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(trk_file.read(), - open(self.simple_trk_filename, 'rb').read()) + open(DATA['simple_trk_fname'], 'rb').read()) def test_write_complex_file(self): # With scalars - tractogram = Tractogram(self.streamlines, - data_per_point=self.data_per_point, + tractogram = Tractogram(DATA['streamlines'], + data_per_point=DATA['data_per_point'], affine_to_rasmm=np.eye(4)) trk_file = BytesIO() @@ -222,8 +224,9 @@ def test_write_complex_file(self): assert_tractogram_equal(new_trk.tractogram, tractogram) # With properties - tractogram = Tractogram(self.streamlines, - data_per_streamline=self.data_per_streamline, + data_per_streamline = DATA['data_per_streamline'] + tractogram = Tractogram(DATA['streamlines'], + data_per_streamline=data_per_streamline, affine_to_rasmm=np.eye(4)) trk = TrkFile(tractogram) @@ -235,9 +238,10 @@ def test_write_complex_file(self): assert_tractogram_equal(new_trk.tractogram, tractogram) # With scalars and properties - tractogram = Tractogram(self.streamlines, - data_per_point=self.data_per_point, - data_per_streamline=self.data_per_streamline, + data_per_streamline = DATA['data_per_streamline'] + tractogram = Tractogram(DATA['streamlines'], + data_per_point=DATA['data_per_point'], + data_per_streamline=data_per_streamline, affine_to_rasmm=np.eye(4)) trk_file = BytesIO() @@ -248,31 +252,31 @@ def test_write_complex_file(self): new_trk = TrkFile.load(trk_file, lazy_load=False) assert_tractogram_equal(new_trk.tractogram, tractogram) - new_trk_orig = TrkFile.load(self.complex_trk_filename) + new_trk_orig = TrkFile.load(DATA['complex_trk_fname']) assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(trk_file.read(), - open(self.complex_trk_filename, 'rb').read()) + open(DATA['complex_trk_fname'], 'rb').read()) def 
test_load_write_file(self): - for filename in [self.empty_trk_filename, - self.simple_trk_filename, - self.complex_trk_filename]: + for fname in [DATA['empty_trk_fname'], + DATA['simple_trk_fname'], + DATA['complex_trk_fname']]: for lazy_load in [False, True]: - trk = TrkFile.load(filename, lazy_load=lazy_load) + trk = TrkFile.load(fname, lazy_load=lazy_load) trk_file = BytesIO() trk.save(trk_file) - new_trk = TrkFile.load(filename, lazy_load=False) + new_trk = TrkFile.load(fname, lazy_load=False) assert_tractogram_equal(new_trk.tractogram, trk.tractogram) trk_file.seek(0, os.SEEK_SET) def test_load_write_LPS_file(self): # Load the RAS and LPS version of the standard. - trk_RAS = TrkFile.load(self.standard_trk_filename, lazy_load=False) - trk_LPS = TrkFile.load(self.standard_LPS_trk_filename, lazy_load=False) + trk_RAS = TrkFile.load(DATA['standard_trk_fname'], lazy_load=False) + trk_LPS = TrkFile.load(DATA['standard_LPS_trk_fname'], lazy_load=False) assert_tractogram_equal(trk_LPS.tractogram, trk_RAS.tractogram) # Write back the standard. @@ -286,12 +290,12 @@ def test_load_write_LPS_file(self): assert_header_equal(new_trk.header, trk.header) assert_tractogram_equal(new_trk.tractogram, trk.tractogram) - new_trk_orig = TrkFile.load(self.standard_LPS_trk_filename) + new_trk_orig = TrkFile.load(DATA['standard_LPS_trk_fname']) assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(trk_file.read(), - open(self.standard_LPS_trk_filename, 'rb').read()) + open(DATA['standard_LPS_trk_fname'], 'rb').read()) # Test writing a file where the header is missing the # Field.VOXEL_ORDER. @@ -310,12 +314,12 @@ def test_load_write_LPS_file(self): assert_header_equal(new_trk.header, trk_LPS.header) assert_tractogram_equal(new_trk.tractogram, trk.tractogram) - new_trk_orig = TrkFile.load(self.standard_LPS_trk_filename) + new_trk_orig = TrkFile.load(DATA['standard_LPS_trk_fname']) assert_tractogram_equal(new_trk.tractogram, new_trk_orig.tractogram) trk_file.seek(0, os.SEEK_SET) assert_equal(trk_file.read(), - open(self.standard_LPS_trk_filename, 'rb').read()) + open(DATA['standard_LPS_trk_fname'], 'rb').read()) def test_write_optional_header_fields(self): # The TRK file format doesn't support additional header fields. @@ -335,9 +339,9 @@ def test_write_too_many_scalars_and_properties(self): # TRK supports up to 10 data_per_point. data_per_point = {} for i in range(10): - data_per_point['#{0}'.format(i)] = self.fa + data_per_point['#{0}'.format(i)] = DATA['fa'] - tractogram = Tractogram(self.streamlines, + tractogram = Tractogram(DATA['streamlines'], data_per_point=data_per_point, affine_to_rasmm=np.eye(4)) @@ -350,9 +354,9 @@ def test_write_too_many_scalars_and_properties(self): assert_tractogram_equal(new_trk.tractogram, tractogram) # More than 10 data_per_point should raise an error. - data_per_point['#{0}'.format(i+1)] = self.fa + data_per_point['#{0}'.format(i+1)] = DATA['fa'] - tractogram = Tractogram(self.streamlines, + tractogram = Tractogram(DATA['streamlines'], data_per_point=data_per_point, affine_to_rasmm=np.eye(4)) @@ -362,9 +366,9 @@ def test_write_too_many_scalars_and_properties(self): # TRK supports up to 10 data_per_streamline. 
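        # (The limit of 10 comes from the TRK header itself, which
        # reserves ten fixed 20-character slots in 'property_name',
        # mirroring the ten 'scalar_name' slots used for data_per_point
        # above.)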
data_per_streamline = {} for i in range(10): - data_per_streamline['#{0}'.format(i)] = self.mean_torsion + data_per_streamline['#{0}'.format(i)] = DATA['mean_torsion'] - tractogram = Tractogram(self.streamlines, + tractogram = Tractogram(DATA['streamlines'], data_per_streamline=data_per_streamline, affine_to_rasmm=np.eye(4)) @@ -377,9 +381,9 @@ def test_write_too_many_scalars_and_properties(self): assert_tractogram_equal(new_trk.tractogram, tractogram) # More than 10 data_per_streamline should raise an error. - data_per_streamline['#{0}'.format(i+1)] = self.mean_torsion + data_per_streamline['#{0}'.format(i+1)] = DATA['mean_torsion'] - tractogram = Tractogram(self.streamlines, + tractogram = Tractogram(DATA['streamlines'], data_per_streamline=data_per_streamline) trk = TrkFile(tractogram) @@ -392,8 +396,8 @@ def test_write_scalars_and_properties_name_too_long(self): # So in reality we allow name of 18 characters, otherwise # the name is truncated and warning is issue. for nb_chars in range(22): - data_per_point = {'A'*nb_chars: self.colors} - tractogram = Tractogram(self.streamlines, + data_per_point = {'A'*nb_chars: DATA['colors']} + tractogram = Tractogram(DATA['streamlines'], data_per_point=data_per_point, affine_to_rasmm=np.eye(4)) @@ -403,8 +407,8 @@ def test_write_scalars_and_properties_name_too_long(self): else: trk.save(BytesIO()) - data_per_point = {'A'*nb_chars: self.fa} - tractogram = Tractogram(self.streamlines, + data_per_point = {'A'*nb_chars: DATA['fa']} + tractogram = Tractogram(DATA['streamlines'], data_per_point=data_per_point, affine_to_rasmm=np.eye(4)) @@ -420,8 +424,8 @@ def test_write_scalars_and_properties_name_too_long(self): # So in reality we allow name of 18 characters, otherwise # the name is truncated and warning is issue. for nb_chars in range(22): - data_per_streamline = {'A'*nb_chars: self.mean_colors} - tractogram = Tractogram(self.streamlines, + data_per_streamline = {'A'*nb_chars: DATA['mean_colors']} + tractogram = Tractogram(DATA['streamlines'], data_per_streamline=data_per_streamline, affine_to_rasmm=np.eye(4)) @@ -431,8 +435,8 @@ def test_write_scalars_and_properties_name_too_long(self): else: trk.save(BytesIO()) - data_per_streamline = {'A'*nb_chars: self.mean_torsion} - tractogram = Tractogram(self.streamlines, + data_per_streamline = {'A'*nb_chars: DATA['mean_torsion']} + tractogram = Tractogram(DATA['streamlines'], data_per_streamline=data_per_streamline, affine_to_rasmm=np.eye(4)) @@ -443,5 +447,5 @@ def test_write_scalars_and_properties_name_too_long(self): trk.save(BytesIO()) def test_str(self): - trk = TrkFile.load(self.complex_trk_filename) + trk = TrkFile.load(DATA['complex_trk_fname']) str(trk) # Simply test it's not failing when called. From 17bd66c0338a0fcf2fddff0bf272d9e0de9f57de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Sat, 28 May 2016 13:25:35 -0400 Subject: [PATCH 110/135] Addressed @jchoude's comments. --- nibabel/streamlines/__init__.py | 3 +- nibabel/streamlines/tests/test_tractogram.py | 26 ++++++++++ nibabel/streamlines/tractogram.py | 53 ++++++++++++-------- 3 files changed, 60 insertions(+), 22 deletions(-) diff --git a/nibabel/streamlines/__init__.py b/nibabel/streamlines/__init__.py index b4c740024e..124bdf8f50 100644 --- a/nibabel/streamlines/__init__.py +++ b/nibabel/streamlines/__init__.py @@ -70,7 +70,8 @@ def load(fileobj, lazy_load=False): of the streamlines file's header). lazy_load : {False, True}, optional If True, load streamlines in a lazy manner i.e. 
they will not be kept - in memory. Otherwise, load all streamlines in memory. + in memory and only be loaded when needed. + Otherwise, load all streamlines in memory. Returns ------- diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 5870693674..8e8f046358 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -334,6 +334,26 @@ def test_tractogram_creation(self): assert_raises(ValueError, Tractogram, DATA['streamlines'], data_per_point=data_per_point) + def test_setting_affine_to_rasmm(self): + tractogram = DATA['tractogram'].copy() + affine = np.diag(range(4)) + + # Test assigning None. + tractogram.affine_to_rasmm = None + assert_true(tractogram.affine_to_rasmm is None) + + # Test assigning a valid ndarray (should make a copy). + tractogram.affine_to_rasmm = affine + assert_true(tractogram.affine_to_rasmm is not affine) + + # Test assigning a list of lists. + tractogram.affine_to_rasmm = affine.tolist() + assert_array_equal(tractogram.affine_to_rasmm, affine) + + # Test assigning a ndarray with wrong shape. + assert_raises(ValueError, setattr, tractogram, + "affine_to_rasmm", affine[::2]) + def test_tractogram_getitem(self): # Retrieve TractogramItem by their index. for i, t in enumerate(DATA['tractogram']): @@ -483,6 +503,12 @@ def test_tractogram_apply_affine(self): for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']): assert_array_almost_equal(s1, s2) + # Test applying the identity transformation. + tractogram = DATA['tractogram'].copy() + tractogram.apply_affine(np.eye(4)) + for s1, s2 in zip(tractogram.streamlines, DATA['streamlines']): + assert_array_almost_equal(s1, s2) + # Test removing affine_to_rasmm tractogram = DATA['tractogram'].copy() tractogram.affine_to_rasmm = None diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index a977238b76..754ef69aad 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -188,7 +188,7 @@ class TractogramItem(object): where N is the number of points. data_for_streamline : dict Dictionary containing some data associated to this particular - streamline. Each key `k` is mapped to a ndarray of shape (Pk,), where + streamline. Each key `k` is mapped to a ndarray of shape (Pt,), where `Pt` is the dimension of the data associated with key `k`. data_for_points : dict Dictionary containing some data associated to each point of this @@ -225,17 +225,17 @@ class Tractogram(object): streamline $t$. data_per_streamline : :class:`PerArrayDict` object Dictionary where the items are (str, 2D array). - Each key represents an information $i$ to be kept along side every + Each key represents an information $i$ to be kept alongside every streamline, and its associated value is a 2D array of shape ($T$, $P_i$) where $T$ is the number of streamlines and $P_i$ is - the number scalar values to store for that particular information $i$. + the number of values to store for that particular information $i$. data_per_point : :class:`PerArraySequenceDict` object Dictionary where the items are (str, :class:`ArraySequence`). 
- Each key represents an information $i$ to be kept along side every + Each key represents an information $i$ to be kept alongside every point of every streamline, and its associated value is an iterable of ndarrays of shape ($N_t$, $M_i$) where $N_t$ is the number of points for a particular streamline $t$ and $M_i$ is the number - scalar values to store for that particular information $i$. + values to store for that particular information $i$. """ def __init__(self, streamlines=None, data_per_streamline=None, @@ -250,13 +250,13 @@ def __init__(self, streamlines=None, streamline $t$. data_per_streamline : dict of iterable of ndarrays, optional Dictionary where the items are (str, iterable). - Each key represents an information $i$ to be kept along side every + Each key represents an information $i$ to be kept alongside every streamline, and its associated value is an iterable of ndarrays of - shape ($P_i$,) where $P_i$ is the number scalar values to store + shape ($P_i$,) where $P_i$ is the number of scalar values to store for that particular information $i$. data_per_point : dict of iterable of ndarrays, optional Dictionary where the items are (str, iterable). - Each key represents an information $i$ to be kept along side every + Each key represents an information $i$ to be kept alongside every point of every streamline, and its associated value is an iterable of ndarrays of shape ($N_t$, $M_i$) where $N_t$ is the number of points for a particular streamline $t$ and $M_i$ is the number @@ -303,6 +303,13 @@ def affine_to_rasmm(self): @affine_to_rasmm.setter def affine_to_rasmm(self, value): + if value is not None: + value = np.array(value) + if value.shape != (4, 4): + msg = ("Affine matrix has a shape of (4, 4) but a ndarray with" + "shape {} was provided instead.").format(value.shape) + raise ValueError(msg) + self._affine_to_rasmm = value def __iter__(self): @@ -362,6 +369,9 @@ def apply_affine(self, affine, lazy=False): if len(self.streamlines) == 0: return self + if np.all(affine == np.eye(4)): + return self # No transformation. + BUFFER_SIZE = 10000000 # About 128 Mb since pts shape is 3. for start in range(0, len(self.streamlines.data), BUFFER_SIZE): end = start + BUFFER_SIZE @@ -408,7 +418,7 @@ class LazyTractogram(Tractogram): This container behaves lazily as it uses generator functions to manage streamlines and their data information. This container is thus memory - friendly since it doesn't require having all those data loaded in memory. + friendly since it doesn't require having all this data loaded in memory. Streamlines of a lazy tractogram can be in any coordinate system of your choice as long as you provide the correct `affine_to_rasmm` matrix, at @@ -424,23 +434,26 @@ class LazyTractogram(Tractogram): streamline $t$. data_per_streamline : :class:`LazyDict` object Dictionary where the items are (str, instantiated generator). - Each key represents an information $i$ to be kept along side every + Each key represents an information $i$ to be kept alongside every streamline, and its associated value is a generator function yielding that information via ndarrays of shape ($P_i$,) where - $P_i$ is the number scalar values to store for that particular + $P_i$ is the number of values to store for that particular information $i$. data_per_point : :class:`LazyDict` object Dictionary where the items are (str, instantiated generator). 
- Each key represents an information $i$ to be kept along side every + Each key represents an information $i$ to be kept alongside every point of every streamline, and its associated value is a generator function yielding that information via ndarrays of shape ($N_t$, $M_i$) where $N_t$ is the number of points for a particular - streamline $t$ and $M_i$ is the number scalar values to store for + streamline $t$ and $M_i$ is the number of values to store for that particular information $i$. Notes ----- LazyTractogram objects do not support indexing currently. + LazyTractogram objects are suited for operations that can be linearized + such as applying an affine transformation or converting streamlines from + one file format to another. """ def __init__(self, streamlines=None, data_per_streamline=None, @@ -455,18 +468,18 @@ def __init__(self, streamlines=None, streamline $t$. data_per_streamline : dict of generator functions, optional Dictionary where the items are (str, generator function). - Each key represents an information $i$ to be kept along side every + Each key represents an information $i$ to be kept alongside every streamline, and its associated value is a generator function yielding that information via ndarrays of shape ($P_i$,) where - $P_i$ is the number scalar values to store for that particular + $P_i$ is the number of values to store for that particular information $i$. data_per_point : dict of generator functions, optional Dictionary where the items are (str, generator function). - Each key represents an information $i$ to be kept along side every + Each key represents an information $i$ to be kept alongside every point of every streamline, and its associated value is a generator function yielding that information via ndarrays of shape ($N_t$, $M_i$) where $N_t$ is the number of points for a particular - streamline $t$ and $M_i$ is the number scalar values to store for + streamline $t$ and $M_i$ is the number of values to store for that particular information $i$. affine_to_rasmm : ndarray of shape (4, 4) or None, optional Transformation matrix that brings the streamlines contained in @@ -525,7 +538,7 @@ def create_from(cls, data_func): Parameters ---------- data_func : generator function yielding :class:`TractogramItem` objects - Generator function that whenever it is called starts yielding + Generator function that whenever is called starts yielding :class:`TractogramItem` objects that will be used to instantiate a :class:`LazyTractogram`. @@ -650,9 +663,7 @@ def __len__(self): warn("Number of streamlines will be determined manually by looping" " through the streamlines. If you know the actual number of" " streamlines, you might want to set it beforehand via" - " `self.header.nb_streamlines`." - " Note this will consume any generators used to create this" - " `LazyTractogram` object.", Warning) + " `self.header.nb_streamlines`.", Warning) # Count the number of streamlines. 
self._nb_streamlines = sum(1 for _ in self.streamlines) From be485c6cb863676af51255071a637fde9091e4cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 30 May 2016 23:08:44 -0400 Subject: [PATCH 111/135] Improved speed for loading TRK --- Changelog | 9 +- nibabel/benchmarks/bench_streamlines.py | 80 ++++++++++++ nibabel/streamlines/array_sequence.py | 123 +++++++++++++----- .../streamlines/tests/test_array_sequence.py | 31 +++++ nibabel/streamlines/trk.py | 2 +- 5 files changed, 206 insertions(+), 39 deletions(-) create mode 100644 nibabel/benchmarks/bench_streamlines.py diff --git a/Changelog b/Changelog index 8a7f93b687..a1e6bdf299 100644 --- a/Changelog +++ b/Changelog @@ -36,6 +36,9 @@ References like "pr/298" refer to github pull request numbers. are raising a DataError if the track is truncated when ``strict=True`` (the default), rather than a TypeError when trying to create the points array. + * New API for managing streamlines and their different file formats. This + adds a new module ``nibabel.streamlines`` that will eventually deprecate + the current trackvis reader found in ``nibabel.trackvis``. * 2.0.2 (Monday 23 November 2015) @@ -251,7 +254,7 @@ References like "pr/298" refer to github pull request numbers. the ability to transform to the image with data closest to the cononical image orientation (first axis left-to-right, second back-to-front, third down-to-up) (MB, Jonathan Taylor) - * Gifti format read and write support (preliminary) (Stephen Gerhard) + * Gifti format read and write support (preliminary) (Stephen Gerhard) * Added utilities to use nipy-style data packages, by rip then edit of nipy data package code (MB) * Some improvements to release support (Jarrod Millman, MB, Fernando Perez) @@ -469,7 +472,7 @@ visiting the URL:: * Removed functionality for "NiftiImage.save() raises an IOError exception when writing the image file fails." (Yaroslav Halchenko) - * Added ability to force a filetype when setting the filename or saving + * Added ability to force a filetype when setting the filename or saving a file. * Reverse the order of the 'header' and 'load' argument in the NiftiImage constructor. 'header' is now first as it seems to be used more often. @@ -481,7 +484,7 @@ visiting the URL:: * 0.20070301.2 (Thu, 1 Mar 2007) - * Fixed wrong link to the source tarball in README.html. + * Fixed wrong link to the source tarball in README.html. * 0.20070301.1 (Thu, 1 Mar 2007) diff --git a/nibabel/benchmarks/bench_streamlines.py b/nibabel/benchmarks/bench_streamlines.py new file mode 100644 index 0000000000..37c161e1d8 --- /dev/null +++ b/nibabel/benchmarks/bench_streamlines.py @@ -0,0 +1,80 @@ +""" Benchmarks for load and save of streamlines + +Run benchmarks with:: + + import nibabel as nib + nib.bench() + +If you have doctests enabled by default in nose (with a noserc file or +environment variable), and you have a numpy version <= 1.6.1, this will also run +the doctests, let's hope they pass. 
+ +Run this benchmark with: + + nosetests -s --match '(?:^|[\\b_\\.//-])[Bb]ench' /path/to/bench_streamlines.py +""" +from __future__ import division, print_function + +import numpy as np + +from nibabel.externals.six.moves import zip +from nibabel.tmpdirs import InTemporaryDirectory + +from numpy.testing import assert_array_equal +from nibabel.streamlines import Tractogram +from nibabel.streamlines import TrkFile + +import nibabel as nib +import nibabel.trackvis as tv + +from numpy.testing import measure + + +def bench_load_trk(): + rng = np.random.RandomState(42) + dtype = 'float32' + NB_STREAMLINES = 5000 + NB_POINTS = 1000 + points = [rng.rand(NB_POINTS, 3).astype(dtype) + for i in range(NB_STREAMLINES)] + scalars = [rng.rand(NB_POINTS, 10).astype(dtype) + for i in range(NB_STREAMLINES)] + + repeat = 10 + + with InTemporaryDirectory(): + trk_file = "tmp.trk" + tractogram = Tractogram(points, affine_to_rasmm=np.eye(4)) + TrkFile(tractogram).save(trk_file) + + loaded_streamlines_old = [d[0]-0.5 for d in tv.read(trk_file, points_space="rasmm")[0]] + mtime_old = measure('tv.read(trk_file, points_space="rasmm")', repeat) + print("Old: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old)) + + loaded_streamlines_new = nib.streamlines.load(trk_file, lazy_load=False).streamlines + mtime_new = measure('nib.streamlines.load(trk_file, lazy_load=False)', repeat) + print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) + print("Speedup of %2f" % (mtime_old/mtime_new)) + + for s1, s2 in zip(loaded_streamlines_new, loaded_streamlines_old): + assert_array_equal(s1, s2) + + # Points and scalars + with InTemporaryDirectory(): + + trk_file = "tmp.trk" + tractogram = Tractogram(points, + data_per_point={'scalars': scalars}, + affine_to_rasmm=np.eye(4)) + TrkFile(tractogram).save(trk_file) + + mtime_old = measure('tv.read(trk_file, points_space="rasmm")', repeat) + print("Old: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_old)) + + mtime_new = measure('nib.streamlines.load(trk_file, lazy_load=False)', repeat) + print("New: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_new)) + print("Speedup of %2f" % (mtime_old/mtime_new)) + + +if __name__ == '__main__': + bench_load_trk() diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index adfd98439e..d295e47d1b 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -60,40 +60,83 @@ def __init__(self, iterable=None, buffer_size=4): self._is_view = True return - # Add elements of the iterable. + try: + # If possible try pre-allocating memory. + if len(iterable) > 0: + first_element = np.asarray(iterable[0]) + n_elements = np.sum([len(iterable[i]) + for i in range(len(iterable))]) + new_shape = (n_elements,) + first_element.shape[1:] + self._data = np.empty(new_shape, dtype=first_element.dtype) + except TypeError: + pass + + # Initialize the `ArraySequence` object from iterable's item. + coroutine = self._extend_using_coroutine() + coroutine.send(None) # Run until the first yield. + + for e in iterable: + coroutine.send(e) + + coroutine.close() # Terminate coroutine. + + def _extend_using_coroutine(self, buffer_size=4): + """ Creates a coroutine allowing to append elements. + + Parameters + ---------- + buffer_size : float, optional + Size (in Mb) for memory pre-allocation. + + Returns + ------- + coroutine + Coroutine object which expects the values to be appended to this + array sequence. 
+ + Notes + ----- + This method is essential for + :func:`create_arraysequences_from_generator` as it allows for an + efficient way of creating multiple array sequences in a hyperthreaded + fashion and still benefit from the memory buffering. Whitout this + method the alternative would be to use :meth:`append` which does + not have such buffering mechanism and thus is at least one order of + magnitude slower. + """ offsets = [] lengths = [] - # Initialize the `ArraySequence` object from iterable's item. - offset = 0 - for i, e in enumerate(iterable): - e = np.asarray(e) - if i == 0: - try: - n_elements = np.sum([len(iterable[i]) - for i in range(len(iterable))]) - new_shape = (n_elements,) + e.shape[1:] - except TypeError: - # Can't get the number of elements in iterable. So, - # we use a memory buffer while building the ArraySequence. + + offset = 0 if len(self) == 0 else self._offsets[-1] + self._lengths[-1] + try: + first_element = True + while True: + e = (yield) + e = np.asarray(e) + if first_element: + first_element = False n_rows_buffer = int(buffer_size * 1024**2 // e.nbytes) new_shape = (n_rows_buffer,) + e.shape[1:] + if len(self) == 0: + self._data = np.empty(new_shape, dtype=e.dtype) - self._data = np.empty(new_shape, dtype=e.dtype) + end = offset + len(e) + if end > len(self._data): + # Resize needed, adding `len(e)` items plus some buffer. + nb_points = len(self._data) + nb_points += len(e) + n_rows_buffer + self._data.resize((nb_points,) + self.common_shape) - end = offset + len(e) - if end > len(self._data): - # Resize needed, adding `len(e)` items plus some buffer. - nb_points = len(self._data) - nb_points += len(e) + n_rows_buffer - self._data.resize((nb_points,) + self.common_shape) + offsets.append(offset) + lengths.append(len(e)) + self._data[offset:offset + len(e)] = e + offset += len(e) - offsets.append(offset) - lengths.append(len(e)) - self._data[offset:offset + len(e)] = e - offset += len(e) + except GeneratorExit: + pass - self._offsets = np.asarray(offsets) - self._lengths = np.asarray(lengths) + self._offsets = np.concatenate([self._offsets, offsets], axis=0) + self._lengths = np.concatenate([self._lengths, lengths], axis=0) # Clear unused memory. self._data.resize((offset,) + self.common_shape) @@ -266,13 +309,6 @@ def __getitem__(self, idx): seq._is_view = True return seq - # for name, slice_ in data_per_point_slice.items(): - # seq = ArraySequence() - # seq._data = scalars._data[:, slice_] - # seq._offsets = scalars._offsets - # seq._lengths = scalars._lengths - # tractogram.data_per_point[name] = seq - raise TypeError("Index must be either an int, a slice, a list of int" " or a ndarray of bool! Not " + str(type(idx))) @@ -320,10 +356,27 @@ def load(cls, filename): def create_arraysequences_from_generator(gen, n): """ Creates :class:`ArraySequence` objects from a generator yielding tuples + + Parameters + ---------- + gen : generator + Generator yielding a size `n` tuple containing the values to put in the + array sequences. + n : int + Number of :class:`ArraySequences` object to create. 
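+
+    Examples
+    --------
+    A sketch of the intended use, mirroring how the TRK loader unpacks a
+    generator of (points, scalars, properties) tuples::
+
+        arr_seqs = create_arraysequences_from_generator(gen, n=3)
+        streamlines, scalars, properties = arr_seqs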
""" seqs = [ArraySequence() for _ in range(n)] + coroutines = [seq._extend_using_coroutine() for seq in seqs] + + for coroutine in coroutines: + coroutine.send(None) + for data in gen: - for i, seq in enumerate(seqs): - seq.append(data[i]) + for i, coroutine in enumerate(coroutines): + if data[i].nbytes > 0: + coroutine.send(data[i]) + + for coroutine in coroutines: + coroutine.close() return seqs diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index b0f3d81708..ed0ae84c05 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -200,6 +200,37 @@ def test_arraysequence_extend(self): seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. assert_raises(ValueError, seq.extend, data) + def test_arraysequence_extend_using_coroutine(self): + new_data = generate_data(nb_arrays=10, + common_shape=SEQ_DATA['seq'].common_shape, + rng=SEQ_DATA['rng']) + + # Extend with an empty list. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + coroutine = seq._extend_using_coroutine() + coroutine.send(None) + coroutine.close() + check_arr_seq(seq, SEQ_DATA['data']) + + # Extend with a list of ndarrays. + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + coroutine = seq._extend_using_coroutine() + coroutine.send(None) + for e in new_data: + coroutine.send(e) + coroutine.close() + check_arr_seq(seq, SEQ_DATA['data'] + new_data) + + # Extend with elements of different shape. + data = generate_data(nb_arrays=10, + common_shape=SEQ_DATA['seq'].common_shape*2, + rng=SEQ_DATA['rng']) + seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. + + coroutine = seq._extend_using_coroutine() + coroutine.send(None) + assert_raises(ValueError, coroutine.send, data[0]) + def test_arraysequence_getitem(self): # Get one item for i, e in enumerate(SEQ_DATA['seq']): diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index af733aa46b..e388c662e6 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -425,7 +425,7 @@ def save(self, fileobj): property_name = np.zeros(MAX_NB_NAMED_PROPERTIES_PER_STREAMLINE, dtype='S20') for i, name in enumerate(data_for_streamline_keys): - # Use the last to bytes of the name to store the number of + # Use the last two bytes of the name to store the number of # values associated to this data_for_streamline. 
nb_values = data_for_streamline[name].shape[-1] property_name[i] = encode_value_in_name(nb_values, name) From d62facc93008cf767360bb923dc50715dd3b6028 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 31 May 2016 14:05:05 -0400 Subject: [PATCH 112/135] Fixed typos --- nibabel/benchmarks/bench_streamlines.py | 45 ++++++--- nibabel/streamlines/array_sequence.py | 122 ++++++++++++------------ nibabel/streamlines/trk.py | 11 +-- 3 files changed, 93 insertions(+), 85 deletions(-) diff --git a/nibabel/benchmarks/bench_streamlines.py b/nibabel/benchmarks/bench_streamlines.py index 37c161e1d8..95ba79fb61 100644 --- a/nibabel/benchmarks/bench_streamlines.py +++ b/nibabel/benchmarks/bench_streamlines.py @@ -47,16 +47,20 @@ def bench_load_trk(): tractogram = Tractogram(points, affine_to_rasmm=np.eye(4)) TrkFile(tractogram).save(trk_file) - loaded_streamlines_old = [d[0]-0.5 for d in tv.read(trk_file, points_space="rasmm")[0]] + streamlines_old = [d[0] - 0.5 + for d in tv.read(trk_file, points_space="rasmm")[0]] mtime_old = measure('tv.read(trk_file, points_space="rasmm")', repeat) - print("Old: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old)) - - loaded_streamlines_new = nib.streamlines.load(trk_file, lazy_load=False).streamlines - mtime_new = measure('nib.streamlines.load(trk_file, lazy_load=False)', repeat) - print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new)) - print("Speedup of %2f" % (mtime_old/mtime_new)) - - for s1, s2 in zip(loaded_streamlines_new, loaded_streamlines_old): + print("Old: Loaded {:,} streamlines in {:6.2f}".format(NB_STREAMLINES, + mtime_old)) + + trk = nib.streamlines.load(trk_file, lazy_load=False) + streamlines_new = trk.streamlines + mtime_new = measure('nib.streamlines.load(trk_file, lazy_load=False)', + repeat) + print("\nNew: Loaded {:,} streamlines in {:6.2}".format(NB_STREAMLINES, + mtime_new)) + print("Speedup of {:.2f}".format(mtime_old / mtime_new)) + for s1, s2 in zip(streamlines_new, streamlines_old): assert_array_equal(s1, s2) # Points and scalars @@ -68,13 +72,24 @@ def bench_load_trk(): affine_to_rasmm=np.eye(4)) TrkFile(tractogram).save(trk_file) - mtime_old = measure('tv.read(trk_file, points_space="rasmm")', repeat) - print("Old: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_old)) - - mtime_new = measure('nib.streamlines.load(trk_file, lazy_load=False)', repeat) - print("New: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_new)) - print("Speedup of %2f" % (mtime_old/mtime_new)) + streamlines_old = [d[0] - 0.5 + for d in tv.read(trk_file, points_space="rasmm")[0]] + scalars_old = [d[1] + for d in tv.read(trk_file, points_space="rasmm")[0]] + mtime_old = measure('tv.read(trk_file, points_space="rasmm")', repeat) + msg = "Old: Loaded {:,} streamlines with scalars in {:6.2f}" + print(msg.format(NB_STREAMLINES, mtime_old)) + + trk = nib.streamlines.load(trk_file, lazy_load=False) + scalars_new = trk.tractogram.data_per_point['scalars'] + mtime_new = measure('nib.streamlines.load(trk_file, lazy_load=False)', + repeat) + msg = "New: Loaded {:,} streamlines with scalars in {:6.2f}" + print(msg.format(NB_STREAMLINES, mtime_new)) + print("Speedup of {:2f}".format(mtime_old / mtime_new)) + for s1, s2 in zip(scalars_new, scalars_old): + assert_array_equal(s1, s2) if __name__ == '__main__': bench_load_trk() diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index d295e47d1b..4de989de1d 100644 --- 
a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -80,67 +80,6 @@ def __init__(self, iterable=None, buffer_size=4): coroutine.close() # Terminate coroutine. - def _extend_using_coroutine(self, buffer_size=4): - """ Creates a coroutine allowing to append elements. - - Parameters - ---------- - buffer_size : float, optional - Size (in Mb) for memory pre-allocation. - - Returns - ------- - coroutine - Coroutine object which expects the values to be appended to this - array sequence. - - Notes - ----- - This method is essential for - :func:`create_arraysequences_from_generator` as it allows for an - efficient way of creating multiple array sequences in a hyperthreaded - fashion and still benefit from the memory buffering. Whitout this - method the alternative would be to use :meth:`append` which does - not have such buffering mechanism and thus is at least one order of - magnitude slower. - """ - offsets = [] - lengths = [] - - offset = 0 if len(self) == 0 else self._offsets[-1] + self._lengths[-1] - try: - first_element = True - while True: - e = (yield) - e = np.asarray(e) - if first_element: - first_element = False - n_rows_buffer = int(buffer_size * 1024**2 // e.nbytes) - new_shape = (n_rows_buffer,) + e.shape[1:] - if len(self) == 0: - self._data = np.empty(new_shape, dtype=e.dtype) - - end = offset + len(e) - if end > len(self._data): - # Resize needed, adding `len(e)` items plus some buffer. - nb_points = len(self._data) - nb_points += len(e) + n_rows_buffer - self._data.resize((nb_points,) + self.common_shape) - - offsets.append(offset) - lengths.append(len(e)) - self._data[offset:offset + len(e)] = e - offset += len(e) - - except GeneratorExit: - pass - - self._offsets = np.concatenate([self._offsets, offsets], axis=0) - self._lengths = np.concatenate([self._lengths, lengths], axis=0) - - # Clear unused memory. - self._data.resize((offset,) + self.common_shape) - @property def is_array_sequence(self): return True @@ -238,6 +177,67 @@ def extend(self, elements): self._lengths = np.r_[self._lengths, elements._lengths] self._offsets = np.r_[self._offsets, offsets] + def _extend_using_coroutine(self, buffer_size=4): + """ Creates a coroutine allowing to append elements. + + Parameters + ---------- + buffer_size : float, optional + Size (in Mb) for memory pre-allocation. + + Returns + ------- + coroutine + Coroutine object which expects the values to be appended to this + array sequence. + + Notes + ----- + This method is essential for + :func:`create_arraysequences_from_generator` as it allows for an + efficient way of creating multiple array sequences in a hyperthreaded + fashion and still benefit from the memory buffering. Whitout this + method the alternative would be to use :meth:`append` which does + not have such buffering mechanism and thus is at least one order of + magnitude slower. + """ + offsets = [] + lengths = [] + + offset = 0 if len(self) == 0 else self._offsets[-1] + self._lengths[-1] + try: + first_element = True + while True: + e = (yield) + e = np.asarray(e) + if first_element: + first_element = False + n_rows_buffer = int(buffer_size * 1024**2 // e.nbytes) + new_shape = (n_rows_buffer,) + e.shape[1:] + if len(self) == 0: + self._data = np.empty(new_shape, dtype=e.dtype) + + end = offset + len(e) + if end > len(self._data): + # Resize needed, adding `len(e)` items plus some buffer. 
+ nb_points = len(self._data) + nb_points += len(e) + n_rows_buffer + self._data.resize((nb_points,) + self.common_shape) + + offsets.append(offset) + lengths.append(len(e)) + self._data[offset:offset + len(e)] = e + offset += len(e) + + except GeneratorExit: + pass + + self._offsets = np.r_[self._offsets, offsets].astype(np.intp) + self._lengths = np.r_[self._lengths, lengths].astype(np.intp) + + # Clear unused memory. + self._data.resize((offset,) + self.common_shape) + def copy(self): """ Creates a copy of this :class:`ArraySequence` object. diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index e388c662e6..8449411711 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -601,8 +601,8 @@ def _read(fileobj, header): data : tuple of ndarrays Streamline data: points, scalars, properties. points: ndarray of shape (n_pts, 3) - scalars: ndarray of shape (n_pts, nb_scalars_per_per_point) - properties: ndarray of shape (nb_properties_per_per_point,) + scalars: ndarray of shape (n_pts, nb_scalars_per_point) + properties: ndarray of shape (nb_properties_per_point,) """ i4_dtype = np.dtype(header[Field.ENDIANNESS] + "i4") f4_dtype = np.dtype(header[Field.ENDIANNESS] + "f4") @@ -664,13 +664,6 @@ def _read(fileobj, header): def __str__(self): """ Gets a formatted string of the header of a TRK file. - Parameters - ---------- - fileobj : string or file-like object - If string, a filename; otherwise an open file-like object - pointing to TRK file (and ready to read from the beginning - of the header). - Returns ------- info : string From 2961bf5dd49327123732804f12a3f860849197d4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 1 Jun 2016 14:06:23 -0400 Subject: [PATCH 113/135] Moved data for tests. --- nibabel/streamlines/tests/test_streamlines.py | 9 ++++----- nibabel/streamlines/tests/test_trk.py | 14 +++++++------- nibabel/{streamlines => }/tests/data/complex.trk | Bin .../tests/data/complex_big_endian.trk | Bin nibabel/{streamlines => }/tests/data/empty.trk | Bin .../{streamlines => }/tests/data/gen_standard.py | 0 nibabel/{streamlines => }/tests/data/simple.trk | Bin .../{streamlines => }/tests/data/standard.LPS.trk | Bin .../{streamlines => }/tests/data/standard.nii.gz | Bin nibabel/{streamlines => }/tests/data/standard.trk | Bin 10 files changed, 11 insertions(+), 12 deletions(-) rename nibabel/{streamlines => }/tests/data/complex.trk (100%) rename nibabel/{streamlines => }/tests/data/complex_big_endian.trk (100%) rename nibabel/{streamlines => }/tests/data/empty.trk (100%) rename nibabel/{streamlines => }/tests/data/gen_standard.py (100%) rename nibabel/{streamlines => }/tests/data/simple.trk (100%) rename nibabel/{streamlines => }/tests/data/standard.LPS.trk (100%) rename nibabel/{streamlines => }/tests/data/standard.nii.gz (100%) rename nibabel/{streamlines => }/tests/data/standard.trk (100%) diff --git a/nibabel/streamlines/tests/test_streamlines.py b/nibabel/streamlines/tests/test_streamlines.py index e814b3b620..c2f1c066d3 100644 --- a/nibabel/streamlines/tests/test_streamlines.py +++ b/nibabel/streamlines/tests/test_streamlines.py @@ -9,6 +9,7 @@ from nibabel.externals.six import BytesIO from nibabel.tmpdirs import InTemporaryDirectory +from nibabel.testing import data_path from nibabel.testing import clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true, assert_false @@ -17,18 +18,16 @@ from ..tractogram_file import TractogramFile, ExtensionWarning from .. 
import trk -DATA_PATH = pjoin(os.path.dirname(__file__), 'data') - DATA = {} def setup(): global DATA - DATA['empty_filenames'] = [pjoin(DATA_PATH, "empty" + ext) + DATA['empty_filenames'] = [pjoin(data_path, "empty" + ext) for ext in nib.streamlines.FORMATS.keys()] - DATA['simple_filenames'] = [pjoin(DATA_PATH, "simple" + ext) + DATA['simple_filenames'] = [pjoin(data_path, "simple" + ext) for ext in nib.streamlines.FORMATS.keys()] - DATA['complex_filenames'] = [pjoin(DATA_PATH, "complex" + ext) + DATA['complex_filenames'] = [pjoin(data_path, "complex" + ext) for ext in nib.streamlines.FORMATS.keys()] DATA['streamlines'] = [np.arange(1*3, dtype="f4").reshape((1, 3)), diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 87bceee36b..cb89ba3ec2 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -6,6 +6,7 @@ from nibabel.externals.six import BytesIO +from nibabel.testing import data_path from nibabel.testing import clear_and_catch_warnings from nose.tools import assert_equal, assert_raises, assert_true from numpy.testing import assert_array_equal @@ -18,24 +19,23 @@ from ..trk import TrkFile from ..header import Field -DATA_PATH = pjoin(os.path.dirname(__file__), 'data') DATA = {} def setup(): global DATA - DATA['empty_trk_fname'] = pjoin(DATA_PATH, "empty.trk") + DATA['empty_trk_fname'] = pjoin(data_path, "empty.trk") # simple.trk contains only streamlines - DATA['simple_trk_fname'] = pjoin(DATA_PATH, "simple.trk") + DATA['simple_trk_fname'] = pjoin(data_path, "simple.trk") # standard.trk contains only streamlines - DATA['standard_trk_fname'] = pjoin(DATA_PATH, "standard.trk") + DATA['standard_trk_fname'] = pjoin(data_path, "standard.trk") # standard.LPS.trk contains only streamlines - DATA['standard_LPS_trk_fname'] = pjoin(DATA_PATH, "standard.LPS.trk") + DATA['standard_LPS_trk_fname'] = pjoin(data_path, "standard.LPS.trk") # complex.trk contains streamlines, scalars and properties - DATA['complex_trk_fname'] = pjoin(DATA_PATH, "complex.trk") - DATA['complex_trk_big_endian_fname'] = pjoin(DATA_PATH, + DATA['complex_trk_fname'] = pjoin(data_path, "complex.trk") + DATA['complex_trk_big_endian_fname'] = pjoin(data_path, "complex_big_endian.trk") DATA['streamlines'] = [np.arange(1*3, dtype="f4").reshape((1, 3)), diff --git a/nibabel/streamlines/tests/data/complex.trk b/nibabel/tests/data/complex.trk similarity index 100% rename from nibabel/streamlines/tests/data/complex.trk rename to nibabel/tests/data/complex.trk diff --git a/nibabel/streamlines/tests/data/complex_big_endian.trk b/nibabel/tests/data/complex_big_endian.trk similarity index 100% rename from nibabel/streamlines/tests/data/complex_big_endian.trk rename to nibabel/tests/data/complex_big_endian.trk diff --git a/nibabel/streamlines/tests/data/empty.trk b/nibabel/tests/data/empty.trk similarity index 100% rename from nibabel/streamlines/tests/data/empty.trk rename to nibabel/tests/data/empty.trk diff --git a/nibabel/streamlines/tests/data/gen_standard.py b/nibabel/tests/data/gen_standard.py similarity index 100% rename from nibabel/streamlines/tests/data/gen_standard.py rename to nibabel/tests/data/gen_standard.py diff --git a/nibabel/streamlines/tests/data/simple.trk b/nibabel/tests/data/simple.trk similarity index 100% rename from nibabel/streamlines/tests/data/simple.trk rename to nibabel/tests/data/simple.trk diff --git a/nibabel/streamlines/tests/data/standard.LPS.trk b/nibabel/tests/data/standard.LPS.trk similarity index 100% rename from 
nibabel/streamlines/tests/data/standard.LPS.trk rename to nibabel/tests/data/standard.LPS.trk diff --git a/nibabel/streamlines/tests/data/standard.nii.gz b/nibabel/tests/data/standard.nii.gz similarity index 100% rename from nibabel/streamlines/tests/data/standard.nii.gz rename to nibabel/tests/data/standard.nii.gz diff --git a/nibabel/streamlines/tests/data/standard.trk b/nibabel/tests/data/standard.trk similarity index 100% rename from nibabel/streamlines/tests/data/standard.trk rename to nibabel/tests/data/standard.trk From 3fc059309de303b256a463346b965181340d4a93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 1 Jun 2016 16:46:51 -0400 Subject: [PATCH 114/135] Followed @matthew-brett's suggestion for TRK header hack. --- .../streamlines/tests/test_array_sequence.py | 5 ++ nibabel/streamlines/trk.py | 84 +++++++++++------- nibabel/tests/data/complex.trk | Bin 1296 -> 1296 bytes nibabel/tests/data/complex_big_endian.trk | Bin 1296 -> 1296 bytes 4 files changed, 59 insertions(+), 30 deletions(-) diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index ed0ae84c05..5e38566b0b 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -280,6 +280,11 @@ def test_arraysequence_getitem(self): check_arr_seq_view(seq_view, SEQ_DATA['seq']) check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data']]) + # Combining multiple slicing and indexing operations. + seq_view = SEQ_DATA['seq'][::-2][:, 2] + check_arr_seq_view(seq_view, SEQ_DATA['seq']) + check_arr_seq(seq_view, [d[:, 2] for d in SEQ_DATA['data'][::-2]]) + def test_arraysequence_repr(self): # Test that calling repr on a ArraySequence object is not falling. repr(SEQ_DATA['seq']) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 8449411711..17ebce70a6 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -129,11 +129,11 @@ def get_affine_rasmm_to_trackvis(header): def encode_value_in_name(value, name, max_name_len=20): - """ Encodes a value in the last two bytes of a string. + """ Encodes a value in the last bytes of a string. - If `value` is one, then there is no encoding and the last two bytes - are left untouched. Otherwise, the byte before the last will be - set to \x00 and the last byte will correspond to the value. + If `value` is one, then there is no encoding and the last bytes + are left untouched. Otherwise, a \x00 byte is added after `name` + and followed by the ascii represensation of the value. This function also verifies that the length of name is less than `max_name_len`. @@ -157,20 +157,56 @@ def encode_value_in_name(value, name, max_name_len=20): msg = ("Data information named '{0}' is too long" " (max {1} characters.)").format(name, max_name_len) raise ValueError(msg) - elif len(name) > max_name_len - 2 and value > 1: + elif value > 1 and len(name) + len(str(value)) + 1 > max_name_len: msg = ("Data information named '{0}' is too long (need to be less" " than {1} characters when storing more than one value" " for a given data information." - ).format(name, max_name_len - 2) + ).format(name, max_name_len - (len(str(value)) + 1)) raise ValueError(msg) - name = name.ljust(max_name_len, '\x00') + encoded_name = name if value > 1: - # Use the last two bytes of `name` to store `value`. 
- name = (asbytes(name[:max_name_len - 2]) + b'\x00' + - np.array(value, dtype=np.int8).tostring()) + # Store the name followed by \x00 and the `value` (in ascii). + encoded_name += '\x00' + str(value) - return name + encoded_name = encoded_name.ljust(max_name_len, '\x00') + return encoded_name + + +def decode_value_from_name(encoded_name): + """ Decodes a value that has been encoded in the last bytes of a string. + + Check :func:`encode_value_in_name` to see how the value has been encoded. + + Parameters + ---------- + encoded_name : bytes + Name in which a value has been encoded or not. + + Returns + ------- + name : bytes + Name without the encoded value. + value : int + Value decoded from the name. + """ + encoded_name = asstr(encoded_name) + if len(encoded_name) == 0: + return encoded_name, 0 + + splits = encoded_name.rstrip('\x00').split('\x00') + name = splits[0] + value = 1 + + if len(splits) == 2: + value = int(splits[1]) # Decode value. + elif len(splits) > 2: + # The remaining bytes are not \x00, raising. + msg = ("Wrong scalar_name or property_name: '{}'." + " Unused characters should be \\x00.").format(name) + raise HeaderError(msg) + + return name, value def create_empty_header(): @@ -289,17 +325,11 @@ def load(cls, fileobj, lazy_load=False): if hdr[Field.NB_SCALARS_PER_POINT] > 0: cpt = 0 for scalar_name in hdr['scalar_name']: - scalar_name = asstr(scalar_name) - if len(scalar_name) == 0: - continue + scalar_name, nb_scalars = decode_value_from_name(scalar_name) - # Check if we encoded the number of values we stocked for this - # scalar name. - nb_scalars = 1 - if scalar_name[-2] == '\x00' and scalar_name[-1] != '\x00': - nb_scalars = int(np.fromstring(scalar_name[-1], np.int8)) + if nb_scalars == 0: + continue - scalar_name = scalar_name.split('\x00')[0] slice_obj = slice(cpt, cpt + nb_scalars) data_per_point_slice[scalar_name] = slice_obj cpt += nb_scalars @@ -312,18 +342,12 @@ def load(cls, fileobj, lazy_load=False): if hdr[Field.NB_PROPERTIES_PER_STREAMLINE] > 0: cpt = 0 for property_name in hdr['property_name']: - property_name = asstr(property_name) - if len(property_name) == 0: - continue + results = decode_value_from_name(property_name) + property_name, nb_properties = results - # Check if we encoded the number of values we stocked for this - # property name. 
- nb_properties = 1 - if property_name[-2] == '\x00' and property_name[-1] != '\x00': - nb_properties = int(np.fromstring(property_name[-1], - np.int8)) + if nb_properties == 0: + continue - property_name = property_name.split('\x00')[0] slice_obj = slice(cpt, cpt + nb_properties) data_per_streamline_slice[property_name] = slice_obj cpt += nb_properties diff --git a/nibabel/tests/data/complex.trk b/nibabel/tests/data/complex.trk index 0a874ea6e74a7646c163b437f5423443a4816105..e2860ee95afaa3f05feb766d49013cd265b840f3 100644 GIT binary patch delta 43 jcmbQhHGyk_F26AY5=cv&%*eE9;yXSLnDAy6MjjRbqUi?H delta 43 kcmbQhHGyk_E Date: Wed, 1 Jun 2016 17:08:25 -0400 Subject: [PATCH 115/135] Check we can add new elements to a ArraySequence loaded from disk --- nibabel/streamlines/tests/test_array_sequence.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index 5e38566b0b..8275b2737a 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -323,3 +323,6 @@ def test_save_and_load_arraysequence(self): assert_array_equal(loaded_seq._data, seq._data) assert_array_equal(loaded_seq._offsets, seq._offsets) assert_array_equal(loaded_seq._lengths, seq._lengths) + + # Make sure we can add new elements to it. + loaded_seq.append(SEQ_DATA['data'][0]) From 075f5e76c1b554fec41f35dafc197027cb2784c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 1 Jun 2016 17:50:58 -0400 Subject: [PATCH 116/135] Added tests for decode_value function --- nibabel/streamlines/tests/test_trk.py | 11 +++++++++++ nibabel/streamlines/trk.py | 4 ++-- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index cb89ba3ec2..b691acc7c3 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -153,6 +153,17 @@ def test_load_file_with_wrong_information(self): new_trk_file = trk_file[:996] + hdr_size + trk_file[996+4:] assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) + # Simulate a TRK file with a wrong scalar_name. + trk_file = open(DATA['complex_trk_fname'], 'rb').read() + noise = np.int32(42).tostring() + new_trk_file = trk_file[:47] + noise + trk_file[47+4:] + assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) + + # Simulate a TRK file with a wrong property_name. + noise = np.int32(42).tostring() + new_trk_file = trk_file[:254] + noise + trk_file[254+4:] + assert_raises(HeaderError, TrkFile.load, BytesIO(new_trk_file)) + def test_load_complex_file_in_big_endian(self): trk_file = open(DATA['complex_trk_big_endian_fname'], 'rb').read() # We use hdr_size as an indicator of little vs big endian. diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 17ebce70a6..6a4f1da7b5 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -12,7 +12,7 @@ import nibabel as nib from nibabel.openers import Opener -from nibabel.py3k import asbytes, asstr +from nibabel.py3k import asstr from nibabel.volumeutils import (native_code, swapped_code) from nibabel.orientations import (aff2axcodes, axcodes2ornt) @@ -203,7 +203,7 @@ def decode_value_from_name(encoded_name): elif len(splits) > 2: # The remaining bytes are not \x00, raising. msg = ("Wrong scalar_name or property_name: '{}'." 
- " Unused characters should be \\x00.").format(name) + " Unused characters should be \\x00.").format(encoded_name) raise HeaderError(msg) return name, value From 18d3c325eeceb2d886fd1df1a2fe6c85a7fda55f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 1 Jun 2016 18:15:44 -0400 Subject: [PATCH 117/135] Supports Python2.6 --- nibabel/streamlines/trk.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 6a4f1da7b5..f310051cfe 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -202,7 +202,7 @@ def decode_value_from_name(encoded_name): value = int(splits[1]) # Decode value. elif len(splits) > 2: # The remaining bytes are not \x00, raising. - msg = ("Wrong scalar_name or property_name: '{}'." + msg = ("Wrong scalar_name or property_name: '{0}'." " Unused characters should be \\x00.").format(encoded_name) raise HeaderError(msg) From 6d38e8c9d49cea3a5e5f1ec811b46a50ad28825b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 1 Jun 2016 18:39:30 -0400 Subject: [PATCH 118/135] Added another test that will be failing until #https://github.com/MarcCote/nibabel/pull/6 is merged. --- nibabel/streamlines/tests/test_tractogram.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index 8e8f046358..d5c1e684e8 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -394,6 +394,16 @@ def test_tractogram_add_new_data(self): t.data_per_streamline[::-1], t.data_per_point[::-1]) + # Add new data to a tractogram for which its `streamlines` is a view. + t = Tractogram(DATA['streamlines']*2, affine_to_rasmm=np.eye(4)) + t = t[:len(DATA['streamlines'])] # Create a view of `streamlines` + t.data_per_point['fa'] = DATA['fa'] + t.data_per_point['colors'] = DATA['colors'] + t.data_per_streamline['mean_curvature'] = DATA['mean_curvature'] + t.data_per_streamline['mean_torsion'] = DATA['mean_torsion'] + t.data_per_streamline['mean_colors'] = DATA['mean_colors'] + assert_tractogram_equal(t, DATA['tractogram']) + def test_tractogram_copy(self): # Create a copy of a tractogram. tractogram = DATA['tractogram'].copy() From e3b4db54e6e9790db99e616dc5f99e3623e6b6b8 Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Sat, 4 Jun 2016 12:11:57 -0700 Subject: [PATCH 119/135] RF: refactor to used cached append method Use caching of append parameters to speed up append for multiple passes. --- nibabel/streamlines/array_sequence.py | 190 ++++++++++++------ .../streamlines/tests/test_array_sequence.py | 24 ++- 2 files changed, 148 insertions(+), 66 deletions(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index 4de989de1d..d02d3d27a9 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -1,6 +1,13 @@ +from __future__ import division + import numbers +from operator import mul +from functools import reduce + import numpy as np +MEGABYTE = 1024 * 1024 + def is_array_sequence(obj): """ Return True if `obj` is an array sequence. 
""" @@ -16,6 +23,26 @@ def is_ndarray_of_int_or_bool(obj): np.issubdtype(obj.dtype, np.bool))) +class _BuildCache(object): + def __init__(self, arr_seq, common_shape, dtype): + self.offsets = list(arr_seq._offsets) + self.lengths = list(arr_seq._lengths) + self.next_offset = arr_seq._get_next_offset() + self.bytes_per_buf = arr_seq._buffer_size * MEGABYTE + self.dtype = dtype + if arr_seq.common_shape != () and common_shape != arr_seq.common_shape: + raise ValueError( + "All dimensions, except the first one, must match exactly") + self.common_shape = common_shape + n_in_row = reduce(mul, common_shape, 1) + bytes_per_row = n_in_row * dtype.itemsize + self.rows_per_buf = bytes_per_row / self.bytes_per_buf + + def update_seq(self, arr_seq): + arr_seq._offsets = np.array(self.offsets) + arr_seq._lengths = np.array(self.lengths) + + class ArraySequence(object): """ Sequence of ndarrays having variable first dimension sizes. @@ -48,6 +75,8 @@ def __init__(self, iterable=None, buffer_size=4): self._data = np.array([]) self._offsets = np.array([], dtype=np.intp) self._lengths = np.array([], dtype=np.intp) + self._buffer_size = buffer_size + self._build_cache = None if iterable is None: return @@ -60,25 +89,24 @@ def __init__(self, iterable=None, buffer_size=4): self._is_view = True return + # If possible try pre-allocating memory. try: - # If possible try pre-allocating memory. - if len(iterable) > 0: - first_element = np.asarray(iterable[0]) - n_elements = np.sum([len(iterable[i]) - for i in range(len(iterable))]) - new_shape = (n_elements,) + first_element.shape[1:] - self._data = np.empty(new_shape, dtype=first_element.dtype) + iter_len = len(iterable) except TypeError: pass - - # Initialize the `ArraySequence` object from iterable's item. - coroutine = self._extend_using_coroutine() - coroutine.send(None) # Run until the first yield. + else: # We do know the iterable length + if iter_len == 0: + return + first_element = np.asarray(iterable[0]) + n_elements = np.sum([len(iterable[i]) + for i in range(len(iterable))]) + new_shape = (n_elements,) + first_element.shape[1:] + self._data = np.empty(new_shape, dtype=first_element.dtype) for e in iterable: - coroutine.send(e) + self.append(e, cache_build=True) - coroutine.close() # Terminate coroutine. + self.finalize_append() @property def is_array_sequence(self): @@ -92,21 +120,40 @@ def common_shape(self): @property def nb_elements(self): """ Total number of elements in this array sequence. """ - return self._data.shape[0] + return np.sum(self._lengths) @property def data(self): """ Elements in this array sequence. """ return self._data - def append(self, element): + def _get_next_offset(self): + """ Offset in ``self._data`` at which to write next element """ + if len(self._offsets) == 0: + return 0 + imax = np.argmax(self._offsets) + return self._offsets[imax] + self._lengths[imax] + + def append(self, element, cache_build=False): """ Appends `element` to this array sequence. + Append can be a lot faster if it knows that it is appending several + elements instead of a single element. In that case it can cache the + parameters it uses between append operations, in a "build cache". To + tell append to do this, use ``cache_build=True``. If you use + ``cache_build=True``, you need to finalize the append operations with + :method:`finalize_append`. + Parameters ---------- element : ndarray Element to append. The shape must match already inserted elements shape except for the first dimension. 
+ cache_build : {False, True} + Whether to save the build cache from this append routine. If True, + append can assume it is the only player updating `self`, and the + caller must finalize `self` after all append operations, with + ``self.finalize_append()``. Returns ------- @@ -118,17 +165,56 @@ def append(self, element): `ArraySequence.extend`. """ element = np.asarray(element) + if element.size == 0: + return + el_shape = element.shape + n_items, common_shape = el_shape[0], el_shape[1:] + build_cache = self._build_cache + in_cached_build = build_cache is not None + if not in_cached_build: # One shot append, not part of sequence + build_cache = _BuildCache(self, common_shape, element.dtype) + next_offset = build_cache.next_offset + req_rows = next_offset + n_items + if self._data.shape[0] < req_rows: + self._resize_data_to(req_rows, build_cache) + self._data[next_offset:req_rows] = element + build_cache.offsets.append(next_offset) + build_cache.lengths.append(n_items) + build_cache.next_offset = req_rows + if in_cached_build: + return + if cache_build: + self._build_cache = build_cache + else: + build_cache.update_seq(self) - if self.common_shape != () and element.shape[1:] != self.common_shape: - msg = "All dimensions, except the first one, must match exactly" - raise ValueError(msg) + def finalize_append(self): + """ Finalize process of appending several elements to `self` - next_offset = self._data.shape[0] - size = (self._data.shape[0] + element.shape[0],) + element.shape[1:] - self._data.resize(size) - self._data[next_offset:] = element - self._offsets = np.r_[self._offsets, next_offset] - self._lengths = np.r_[self._lengths, element.shape[0]] + :method:`append` can be a lot faster if it knows that it is appending + several elements instead of a single element. To tell the append + method this is the case, use ``cache_build=True``. This method + finalizes the series of append operations after a call to + :method:`append` with ``cache_build=True``. + """ + if self._build_cache is None: + return + self._build_cache.update_seq(self) + self._build_cache = None + + def _resize_data_to(self, n_rows, build_cache): + """ Resize data array if required """ + # Calculate new data shape, rounding up to nearest buffer size + n_bufs = np.ceil(n_rows / build_cache.rows_per_buf) + extended_n_rows = int(n_bufs * build_cache.rows_per_buf) + new_shape = (extended_n_rows,) + build_cache.common_shape + if self._data.size == 0: + self._data = np.empty(new_shape, dtype=build_cache.dtype) + else: + self._data.resize(new_shape) + + def shrink_data(self): + self._data.resize((self._get_next_offset(),) + self.common_shape) def extend(self, elements): """ Appends all `elements` to this array sequence. 
@@ -154,28 +240,16 @@ def extend(self, elements): if not is_array_sequence(elements): self.extend(self.__class__(elements)) return - if len(elements) == 0: return - - if (self.common_shape != () and - elements.common_shape != self.common_shape): - msg = "All dimensions, except the first one, must match exactly" - raise ValueError(msg) - - next_offset = self._data.shape[0] - self._data.resize((self._data.shape[0] + sum(elements._lengths), - elements._data.shape[1])) - - offsets = [] - for offset, length in zip(elements._offsets, elements._lengths): - offsets.append(next_offset) - chunk = elements._data[offset:offset + length] - self._data[next_offset:next_offset + length] = chunk - next_offset += length - - self._lengths = np.r_[self._lengths, elements._lengths] - self._offsets = np.r_[self._offsets, offsets] + self._build_cache = _BuildCache(self, + elements.common_shape, + elements.data.dtype) + self._resize_data_to(self._get_next_offset() + elements.nb_elements, + self._build_cache) + for element in elements: + self.append(element) + self.finalize_append() def _extend_using_coroutine(self, buffer_size=4): """ Creates a coroutine allowing to append elements. @@ -204,7 +278,7 @@ def _extend_using_coroutine(self, buffer_size=4): offsets = [] lengths = [] - offset = 0 if len(self) == 0 else self._offsets[-1] + self._lengths[-1] + offset = self._get_next_offset() try: first_element = True while True: @@ -293,20 +367,24 @@ def __getitem__(self, idx): start = self._offsets[idx] return self._data[start:start + self._lengths[idx]] - elif isinstance(idx, (slice, list)) or is_ndarray_of_int_or_bool(idx): - seq = self.__class__() + seq = self.__class__() + seq._is_view = True + if isinstance(idx, tuple): + off_idx = idx[0] + seq._data = self._data.__getitem__((slice(None),) + idx[1:]) + else: + off_idx = idx seq._data = self._data - seq._offsets = self._offsets[idx] - seq._lengths = self._lengths[idx] - seq._is_view = True + + if isinstance(off_idx, slice): # Standard list slicing + seq._offsets = self._offsets[off_idx] + seq._lengths = self._lengths[off_idx] return seq - elif isinstance(idx, tuple): - seq = self.__class__() - seq._data = self._data.__getitem__((slice(None),) + idx[1:]) - seq._offsets = self._offsets[idx[0]] - seq._lengths = self._lengths[idx[0]] - seq._is_view = True + if isinstance(off_idx, list) or is_ndarray_of_int_or_bool(off_idx): + # Fancy indexing + seq._offsets = self._offsets[off_idx] + seq._lengths = self._lengths[off_idx] return seq raise TypeError("Index must be either an int, a slice, a list of int" diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index ed0ae84c05..9350920943 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -52,6 +52,7 @@ def check_arr_seq(seq, arrays): # The only thing we can check is the _lengths. 
assert_array_equal(sorted(seq._lengths), sorted(lengths)) else: + seq.shrink_data() assert_equal(seq._data.shape[0], sum(lengths)) assert_array_equal(seq._data, np.concatenate(arrays, axis=0)) assert_array_equal(seq._offsets, np.r_[0, np.cumsum(lengths)[:-1]]) @@ -113,20 +114,23 @@ def test_arraysequence_iter(self): assert_raises(ValueError, list, seq) def test_arraysequence_copy(self): - seq = SEQ_DATA['seq'].copy() - assert_array_equal(seq._data, SEQ_DATA['seq']._data) - assert_true(seq._data is not SEQ_DATA['seq']._data) - assert_array_equal(seq._offsets, SEQ_DATA['seq']._offsets) - assert_true(seq._offsets is not SEQ_DATA['seq']._offsets) - assert_array_equal(seq._lengths, SEQ_DATA['seq']._lengths) - assert_true(seq._lengths is not SEQ_DATA['seq']._lengths) - assert_equal(seq.common_shape, SEQ_DATA['seq'].common_shape) + orig = SEQ_DATA['seq'] + seq = orig.copy() + n_rows = seq.nb_elements + assert_equal(n_rows, orig.nb_elements) + assert_array_equal(seq._data, orig._data[:n_rows]) + assert_true(seq._data is not orig._data) + assert_array_equal(seq._offsets, orig._offsets) + assert_true(seq._offsets is not orig._offsets) + assert_array_equal(seq._lengths, orig._lengths) + assert_true(seq._lengths is not orig._lengths) + assert_equal(seq.common_shape, orig.common_shape) # Taking a copy of an `ArraySequence` generated by slicing. # Only keep needed data. - seq = SEQ_DATA['seq'][::2].copy() + seq = orig[::2].copy() check_arr_seq(seq, SEQ_DATA['data'][::2]) - assert_true(seq._data is not SEQ_DATA['seq']._data) + assert_true(seq._data is not orig._data) def test_arraysequence_append(self): element = generate_data(nb_arrays=1, From 6ff7e6436c07d32c188d122147eea88a048b7083 Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Sat, 4 Jun 2016 12:39:43 -0700 Subject: [PATCH 120/135] RF: use new cached append for multi-seq read Use the new cached append to replace the coroutines for the read across multiple array sequences in ``create_arraysequences_from_generator``. Before this change (using coroutines): ``` Old: Loaded 5,000 streamlines in 2.85 New: Loaded 5,000 streamlines in 3.6 Speedup of 0.79 Old: Loaded 5,000 streamlines with scalars in 5.16 New: Loaded 5,000 streamlines with scalars in 7.13 Speedup of 0.723703 ``` After this change (using cached append): ``` Old: Loaded 5,000 streamlines in 3.21 New: Loaded 5,000 streamlines in 3.9 Speedup of 0.82 Old: Loaded 5,000 streamlines with scalars in 5.21 New: Loaded 5,000 streamlines with scalars in 7.16 Speedup of 0.727654 ``` This seems to be well within run-to-run measurement error. --- nibabel/streamlines/array_sequence.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index d02d3d27a9..8d907197c4 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -444,17 +444,11 @@ def create_arraysequences_from_generator(gen, n): Number of :class:`ArraySequences` object to create. 
""" seqs = [ArraySequence() for _ in range(n)] - coroutines = [seq._extend_using_coroutine() for seq in seqs] - - for coroutine in coroutines: - coroutine.send(None) - for data in gen: - for i, coroutine in enumerate(coroutines): + for i, seq in enumerate(seqs): if data[i].nbytes > 0: - coroutine.send(data[i]) - - for coroutine in coroutines: - coroutine.close() + seq.append(data[i], cache_build=True) + for seq in seqs: + seq.finalize_append() return seqs From acfbf39221a73497c886f6717d37fc94b7c3f3b0 Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Sat, 4 Jun 2016 20:23:08 -0700 Subject: [PATCH 121/135] RF: refactor init + extend for arraysequence Use extend in init; use trick to save memory allocations in extend routine. --- nibabel/streamlines/array_sequence.py | 51 +++++++++++---------------- 1 file changed, 20 insertions(+), 31 deletions(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index 8d907197c4..ce96f6c7ce 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -29,7 +29,8 @@ def __init__(self, arr_seq, common_shape, dtype): self.lengths = list(arr_seq._lengths) self.next_offset = arr_seq._get_next_offset() self.bytes_per_buf = arr_seq._buffer_size * MEGABYTE - self.dtype = dtype + # Use the passed dtype only if null data array + self.dtype = dtype if arr_seq._data.size == 0 else arr_seq._data.dtype if arr_seq.common_shape != () and common_shape != arr_seq.common_shape: raise ValueError( "All dimensions, except the first one, must match exactly") @@ -89,24 +90,7 @@ def __init__(self, iterable=None, buffer_size=4): self._is_view = True return - # If possible try pre-allocating memory. - try: - iter_len = len(iterable) - except TypeError: - pass - else: # We do know the iterable length - if iter_len == 0: - return - first_element = np.asarray(iterable[0]) - n_elements = np.sum([len(iterable[i]) - for i in range(len(iterable))]) - new_shape = (n_elements,) + first_element.shape[1:] - self._data = np.empty(new_shape, dtype=first_element.dtype) - - for e in iterable: - self.append(e, cache_build=True) - - self.finalize_append() + self.extend(iterable) @property def is_array_sequence(self): @@ -237,18 +221,23 @@ def extend(self, elements): The shape of the elements to be added must match the one of the data of this :class:`ArraySequence` except for the first dimension. """ - if not is_array_sequence(elements): - self.extend(self.__class__(elements)) - return - if len(elements) == 0: - return - self._build_cache = _BuildCache(self, - elements.common_shape, - elements.data.dtype) - self._resize_data_to(self._get_next_offset() + elements.nb_elements, - self._build_cache) - for element in elements: - self.append(element) + # If possible try pre-allocating memory. 
+ try: + iter_len = len(elements) + except TypeError: + pass + else: # We do know the iterable length + if iter_len == 0: + return + e0 = np.asarray(elements[0]) + n_elements = np.sum([len(e) for e in elements]) + self._build_cache = _BuildCache(self, e0.shape[1:], e0.dtype) + self._resize_data_to(self._get_next_offset() + n_elements, + self._build_cache) + + for e in elements: + self.append(e, cache_build=True) + self.finalize_append() def _extend_using_coroutine(self, buffer_size=4): From a9449576f6dffa25d88030ab7b5b7aadc03ef89c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Mon, 6 Jun 2016 08:23:26 -0400 Subject: [PATCH 122/135] Removed _extend_using_coroutines --- nibabel/streamlines/array_sequence.py | 61 ------------------- .../streamlines/tests/test_array_sequence.py | 31 ---------- 2 files changed, 92 deletions(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index ce96f6c7ce..ce36f83d47 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -240,67 +240,6 @@ def extend(self, elements): self.finalize_append() - def _extend_using_coroutine(self, buffer_size=4): - """ Creates a coroutine allowing to append elements. - - Parameters - ---------- - buffer_size : float, optional - Size (in Mb) for memory pre-allocation. - - Returns - ------- - coroutine - Coroutine object which expects the values to be appended to this - array sequence. - - Notes - ----- - This method is essential for - :func:`create_arraysequences_from_generator` as it allows for an - efficient way of creating multiple array sequences in a hyperthreaded - fashion and still benefit from the memory buffering. Whitout this - method the alternative would be to use :meth:`append` which does - not have such buffering mechanism and thus is at least one order of - magnitude slower. - """ - offsets = [] - lengths = [] - - offset = self._get_next_offset() - try: - first_element = True - while True: - e = (yield) - e = np.asarray(e) - if first_element: - first_element = False - n_rows_buffer = int(buffer_size * 1024**2 // e.nbytes) - new_shape = (n_rows_buffer,) + e.shape[1:] - if len(self) == 0: - self._data = np.empty(new_shape, dtype=e.dtype) - - end = offset + len(e) - if end > len(self._data): - # Resize needed, adding `len(e)` items plus some buffer. - nb_points = len(self._data) - nb_points += len(e) + n_rows_buffer - self._data.resize((nb_points,) + self.common_shape) - - offsets.append(offset) - lengths.append(len(e)) - self._data[offset:offset + len(e)] = e - offset += len(e) - - except GeneratorExit: - pass - - self._offsets = np.r_[self._offsets, offsets].astype(np.intp) - self._lengths = np.r_[self._lengths, lengths].astype(np.intp) - - # Clear unused memory. - self._data.resize((offset,) + self.common_shape) - def copy(self): """ Creates a copy of this :class:`ArraySequence` object. diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index 9424be9b3a..4f52c5d987 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -204,37 +204,6 @@ def test_arraysequence_extend(self): seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. 
assert_raises(ValueError, seq.extend, data) - def test_arraysequence_extend_using_coroutine(self): - new_data = generate_data(nb_arrays=10, - common_shape=SEQ_DATA['seq'].common_shape, - rng=SEQ_DATA['rng']) - - # Extend with an empty list. - seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. - coroutine = seq._extend_using_coroutine() - coroutine.send(None) - coroutine.close() - check_arr_seq(seq, SEQ_DATA['data']) - - # Extend with a list of ndarrays. - seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. - coroutine = seq._extend_using_coroutine() - coroutine.send(None) - for e in new_data: - coroutine.send(e) - coroutine.close() - check_arr_seq(seq, SEQ_DATA['data'] + new_data) - - # Extend with elements of different shape. - data = generate_data(nb_arrays=10, - common_shape=SEQ_DATA['seq'].common_shape*2, - rng=SEQ_DATA['rng']) - seq = SEQ_DATA['seq'].copy() # Copy because of in-place modification. - - coroutine = seq._extend_using_coroutine() - coroutine.send(None) - assert_raises(ValueError, coroutine.send, data[0]) - def test_arraysequence_getitem(self): # Get one item for i, e in enumerate(SEQ_DATA['seq']): From 649eff38b51f4f3944b798d94350f962ff8b8e54 Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Mon, 6 Jun 2016 15:41:20 -0700 Subject: [PATCH 123/135] RF: small streamlines API PR edits A little refactoring of the Trk file ``_read_header`` method to save an unnecessary seek. Some rewriting and expansion of docstrings. --- nibabel/streamlines/trk.py | 30 ++++++++++++++++++------------ nibabel/tests/data/gen_standard.py | 7 +++++++ 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index f310051cfe..4587a36ac1 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -541,10 +541,12 @@ def _read_header(fileobj): Returns ------- header : dict - Metadata associated to this tractogram file. + Metadata associated with this tractogram file. """ + # Record start position if this is a file-like object + start_position = fileobj.tell() if hasattr(fileobj, 'tell') else None + with Opener(fileobj) as f: - start_position = f.tell() # Read the header in one block. header_str = f.read(header_2_dtype.itemsize) @@ -568,7 +570,8 @@ def _read_header(fileobj): elif header_rec['version'] == 2: pass # Nothing more to do. else: - raise HeaderError('NiBabel only supports versions 1 and 2.') + raise HeaderError('NiBabel only supports versions 1 and 2 of ' + 'the Trackvis file format') # Convert the first record of `header_rec` into a dictionnary header = dict(zip(header_rec.dtype.names, header_rec[0])) @@ -601,14 +604,15 @@ def _read_header(fileobj): # Keep the file position where the data begin. header['_offset_data'] = f.tell() - # Set the file position where it was (in case it was already open). - f.seek(start_position, os.SEEK_CUR) + # Set the file position where it was, if it was previously open + if start_position is not None: + fileobj.seek(start_position, os.SEEK_CUR) - return header + return header @staticmethod def _read(fileobj, header): - """ Reads TRK data from a file. + """ Return generator that reads TRK data from `fileobj` given `header` Parameters ---------- @@ -618,15 +622,17 @@ def _read(fileobj, header): of the TRK header). Note that calling this function does not change the file position. header : dict - Metadata associated to this tractogram file. + Metadata associated with this tractogram file. 
Yields ------ data : tuple of ndarrays - Streamline data: points, scalars, properties. - points: ndarray of shape (n_pts, 3) - scalars: ndarray of shape (n_pts, nb_scalars_per_point) - properties: ndarray of shape (nb_properties_per_point,) + Length 3 tuple of streamline data of form (points, scalars, + properties), where: + + * points: ndarray of shape (n_pts, 3) + * scalars: ndarray of shape (n_pts, nb_scalars_per_point) + * properties: ndarray of shape (nb_properties_per_point,) """ i4_dtype = np.dtype(header[Field.ENDIANNESS] + "i4") f4_dtype = np.dtype(header[Field.ENDIANNESS] + "f4") diff --git a/nibabel/tests/data/gen_standard.py b/nibabel/tests/data/gen_standard.py index 63b4173602..b97da8ff2f 100644 --- a/nibabel/tests/data/gen_standard.py +++ b/nibabel/tests/data/gen_standard.py @@ -1,3 +1,10 @@ +""" Generate mask and testing tractogram in known formats: + +* mask: standard.nii.gz +* tractogram: + + * standard.trk +""" import numpy as np import nibabel as nib From 5fc5dd943dc3b088793aa71c1a09d07982470796 Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Mon, 6 Jun 2016 16:15:09 -0700 Subject: [PATCH 124/135] RF: move assert_arr_dict_equal to testing Move testing function from parrect tests to general testing utilities. --- nibabel/testing/__init__.py | 9 +++++++++ nibabel/tests/test_parrec.py | 10 ++-------- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/nibabel/testing/__init__.py b/nibabel/testing/__init__.py index 8e0cd982e5..2200b25182 100644 --- a/nibabel/testing/__init__.py +++ b/nibabel/testing/__init__.py @@ -200,3 +200,12 @@ def runif_extra_has(test_str): """Decorator checks to see if NIPY_EXTRA_TESTS env var contains test_str""" return skipif(test_str not in EXTRA_SET, "Skip {0} tests.".format(test_str)) + + +def assert_arr_dict_equal(dict1, dict2): + """ Assert that two dicts are equal, where dicts contain arrays + """ + assert_equal(set(dict1), set(dict2)) + for key, value1 in dict1.items(): + value2 = dict2[key] + assert_array_equal(value1, value2) diff --git a/nibabel/tests/test_parrec.py b/nibabel/tests/test_parrec.py index 63e96c4938..ed50150706 100644 --- a/nibabel/tests/test_parrec.py +++ b/nibabel/tests/test_parrec.py @@ -23,7 +23,8 @@ from nose.tools import (assert_true, assert_false, assert_raises, assert_equal) -from ..testing import clear_and_catch_warnings, suppress_warnings +from ..testing import (clear_and_catch_warnings, suppress_warnings, + assert_arr_dict_equal) from .test_arrayproxy import check_mmap from . import test_spatialimages as tsi @@ -618,13 +619,6 @@ def test_copy_on_init(): assert_array_equal(HDR_DEFS['image pixel size'], 16) -def assert_arr_dict_equal(dict1, dict2): - assert_equal(set(dict1), set(dict2)) - for key, value1 in dict1.items(): - value2 = dict2[key] - assert_array_equal(value1, value2) - - def assert_structarr_equal(star1, star2): # Compare structured arrays (array_equal does not work for np 1.5) assert_equal(star1.dtype, star2.dtype) From bf8b6b5a319464d25b72aac4de8e3ad026e4d13e Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Mon, 6 Jun 2016 16:36:39 -0700 Subject: [PATCH 125/135] BF+TST: test and fix header file position reset Test reset of file position when reading trackvis header. This revealed that we were using the wrong flag to `seek`. 
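The seek-flag bug fixed in this patch is easy to demonstrate in isolation. A minimal sketch, independent of the TRK reader, of why ``os.SEEK_CUR`` cannot restore a saved absolute position while ``os.SEEK_SET`` can:

import io
import os

f = io.BytesIO(b'x' * 100)
f.seek(10)                    # pretend the header starts at offset 10
start_position = f.tell()     # remember where we began: 10
f.read(50)                    # read the header; position is now 60

# SEEK_CUR is *relative* to the current position, so this seeks to
# 60 + 10 = 70 instead of back to the start of the header.
f.seek(start_position, os.SEEK_CUR)
assert f.tell() == 70

# SEEK_SET is an absolute offset from the start of the file.
f.seek(start_position, os.SEEK_SET)
assert f.tell() == 10
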
--- nibabel/streamlines/tests/test_trk.py | 18 +++++++++++++++++- nibabel/streamlines/trk.py | 2 +- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index b691acc7c3..96f541a65d 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -7,7 +7,7 @@ from nibabel.externals.six import BytesIO from nibabel.testing import data_path -from nibabel.testing import clear_and_catch_warnings +from nibabel.testing import clear_and_catch_warnings, assert_arr_dict_equal from nose.tools import assert_equal, assert_raises, assert_true from numpy.testing import assert_array_equal @@ -460,3 +460,19 @@ def test_write_scalars_and_properties_name_too_long(self): def test_str(self): trk = TrkFile.load(DATA['complex_trk_fname']) str(trk) # Simply test it's not failing when called. + + def test_header_read_restore(self): + # Test that reading a header restores the file position + trk_fname = DATA['simple_trk_fname'] + bio = BytesIO() + bio.write(b'Along my very merry way') + hdr_pos = bio.tell() + hdr_from_fname = TrkFile._read_header(trk_fname) + with open(trk_fname, 'rb') as fobj: + bio.write(fobj.read()) + bio.seek(hdr_pos) + # Check header is as expected + hdr_from_fname['_offset_data'] += hdr_pos # Correct for start position + assert_arr_dict_equal(TrkFile._read_header(bio), hdr_from_fname) + # Check fileobject file position has not changed + assert_equal(bio.tell(), hdr_pos) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 4587a36ac1..276144fdd7 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -606,7 +606,7 @@ def _read_header(fileobj): # Set the file position where it was, if it was previously open if start_position is not None: - fileobj.seek(start_position, os.SEEK_CUR) + fileobj.seek(start_position, os.SEEK_SET) return header From 959a1c76b0c6a25b9c1a6fa2c887210fc32cf338 Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Mon, 6 Jun 2016 17:14:21 -0700 Subject: [PATCH 126/135] RF: use generic testing function to hdr1=hdr2 Use ``assert_arr_dict_equal`` to test whether two headers are equal, rather than custom testing function. 
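For context, plain ``==`` between dicts whose values are ndarrays has no single truth value, which is what the shared helper works around: it compares key sets, then each value with ``assert_array_equal``. A small sketch using the helper exactly as added to ``nibabel.testing`` in the previous patch:

import numpy as np
from nibabel.testing import assert_arr_dict_equal

hdr1 = {'dimensions': np.array([2, 3, 4]), 'voxel_order': b'RAS'}
hdr2 = {'dimensions': np.array([2, 3, 4]), 'voxel_order': b'RAS'}

# hdr1 == hdr2 raises ValueError: comparing the ndarray values gives an
# element-wise boolean array that cannot be reduced to True or False.
assert_arr_dict_equal(hdr1, hdr2)  # passes: same keys, equal values
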
--- nibabel/streamlines/tests/test_trk.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 96f541a65d..587f8ac752 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -82,14 +82,6 @@ def setup(): affine_to_rasmm=np.eye(4)) -def assert_header_equal(h1, h2): - for k in h1.keys(): - assert_array_equal(h2[k], h1[k]) - - for k in h2.keys(): - assert_array_equal(h1[k], h2[k]) - - class TestTRK(unittest.TestCase): def test_load_empty_file(self): @@ -298,7 +290,7 @@ def test_load_write_LPS_file(self): new_trk = TrkFile.load(trk_file) - assert_header_equal(new_trk.header, trk.header) + assert_arr_dict_equal(new_trk.header, trk.header) assert_tractogram_equal(new_trk.tractogram, trk.tractogram) new_trk_orig = TrkFile.load(DATA['standard_LPS_trk_fname']) @@ -322,7 +314,7 @@ def test_load_write_LPS_file(self): new_trk = TrkFile.load(trk_file) - assert_header_equal(new_trk.header, trk_LPS.header) + assert_arr_dict_equal(new_trk.header, trk_LPS.header) assert_tractogram_equal(new_trk.tractogram, trk.tractogram) new_trk_orig = TrkFile.load(DATA['standard_LPS_trk_fname']) From 84dd196a1a9f5e1d0cf67e6bc8c696f3d70822ee Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Mon, 6 Jun 2016 17:41:14 -0700 Subject: [PATCH 127/135] RF+TST: add tests for en/decoding value in fields Add tests for encoding / decoding numerical values in byte string fields. Refactor encoding. Update comments / docstrings to note new encoding with numbers as ASCII strings. --- nibabel/streamlines/tests/test_trk.py | 30 ++++++++++++++- nibabel/streamlines/trk.py | 53 +++++++++++++-------------- 2 files changed, 55 insertions(+), 28 deletions(-) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 587f8ac752..09cd93d404 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -16,7 +16,7 @@ from ..tractogram_file import HeaderError, HeaderWarning from .. 
import trk as trk_module -from ..trk import TrkFile +from ..trk import TrkFile, encode_value_in_name, decode_value_from_name from ..header import Field DATA = {} @@ -468,3 +468,31 @@ def test_header_read_restore(self): assert_arr_dict_equal(TrkFile._read_header(bio), hdr_from_fname) # Check fileobject file position has not changed assert_equal(bio.tell(), hdr_pos) + + +def test_encode_names(): + # Test function for encoding numbers into property names + b0 = b'\x00' + assert_equal(encode_value_in_name(0, 'foo', 10), + b'foo' + b0 * 7) + assert_equal(encode_value_in_name(1, 'foo', 10), + b'foo' + b0 * 7) + assert_equal(encode_value_in_name(8, 'foo', 10), + b'foo' + b0 + b'8' + b0 * 5) + assert_equal(encode_value_in_name(40, 'foobar', 10), + b'foobar' + b0 + b'40' + b0) + assert_equal(encode_value_in_name(1, 'foobarbazz', 10), b'foobarbazz') + assert_raises(ValueError, encode_value_in_name, 1, 'foobarbazzz', 10) + assert_raises(ValueError, encode_value_in_name, 2, 'foobarbaz', 10) + assert_equal(encode_value_in_name(2, 'foobarba', 10), b'foobarba\x002') + + +def test_decode_names(): + # Test function for decoding name string into name, number + b0 = b'\x00' + assert_equal(decode_value_from_name(b''), ('', 0)) + assert_equal(decode_value_from_name(b'foo' + b0 * 7), ('foo', 1)) + assert_equal(decode_value_from_name(b'foo\x008' + b0 * 5), ('foo', 8)) + assert_equal(decode_value_from_name(b'foobar\x0010\x00'), ('foobar', 10)) + assert_raises(ValueError, decode_value_from_name, b'foobar\x0010\x01') + assert_raises(HeaderError, decode_value_from_name, b'foo\x0010\x00111') diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 276144fdd7..bd18f3d4fa 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -129,48 +129,46 @@ def get_affine_rasmm_to_trackvis(header): def encode_value_in_name(value, name, max_name_len=20): - """ Encodes a value in the last bytes of a string. + """ Return `name` as fixed-length string, appending `value` as string. - If `value` is one, then there is no encoding and the last bytes - are left untouched. Otherwise, a \x00 byte is added after `name` - and followed by the ascii represensation of the value. + Form output from `name` if `value <= 1` else `name` + ``\x00`` + + str(value). - This function also verifies that the length of name is less - than `max_name_len`. + Return output as fixed length string length `max_name_len`, padded with + ``\x00``. + + This function also verifies that the modified length of name is less than + `max_name_len`. Parameters ---------- - value : byte - Integer value between 0 and 255 to encode. - name : bytes - Name in which the last two bytes will serve to encode `value`. + value : int + Integer value to encode. + name : str + Name to which we may append an ascii / latin-1 representation of + `value`. max_name_len : int, optional - Maximum length name can have. + Maximum length of byte string that output can have. Returns ------- encoded_name : bytes - Name containing the encoded value. + Name maybe followed by ``\x00`` and ascii / latin-1 representation of + `value`, padded with ``\x00`` bytes. 
""" - if len(name) > max_name_len: msg = ("Data information named '{0}' is too long" " (max {1} characters.)").format(name, max_name_len) raise ValueError(msg) - elif value > 1 and len(name) + len(str(value)) + 1 > max_name_len: + encoded_name = name if value <= 1 else name + '\x00' + str(value) + if len(encoded_name) > max_name_len: msg = ("Data information named '{0}' is too long (need to be less" " than {1} characters when storing more than one value" " for a given data information." ).format(name, max_name_len - (len(str(value)) + 1)) raise ValueError(msg) - - encoded_name = name - if value > 1: - # Store the name followed by \x00 and the `value` (in ascii). - encoded_name += '\x00' + str(value) - - encoded_name = encoded_name.ljust(max_name_len, '\x00') - return encoded_name + # Fill to the end with zeros + return encoded_name.ljust(max_name_len, '\x00').encode('latin1') def decode_value_from_name(encoded_name): @@ -388,7 +386,7 @@ def _read(): return cls(tractogram, header=hdr) def save(self, fileobj): - """ Saves tractogram to a file-like object using TRK format. + """ Save tractogram to a filename or file-like object using TRK format. Parameters ---------- @@ -420,6 +418,7 @@ def save(self, fileobj): # Keep track of the beginning of the header. beginning = f.tell() + # Write temporary header that we will update at the end f.write(header.tostring()) i4_dtype = np.dtype(" Date: Mon, 6 Jun 2016 17:50:58 -0700 Subject: [PATCH 128/135] DOC: add docstring for function giving affine Add docstring for function giving RAS+ affine from trackvis header. --- nibabel/streamlines/trk.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index bd18f3d4fa..dc21316293 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -88,6 +88,24 @@ def get_affine_trackvis_to_rasmm(header): + """ Get affine mapping trackvis voxelmm space to RAS+ mm space + + The streamlines in a trackvis file are in 'voxelmm' space, where the + coordinates refer to the corner of the voxel. + + Compute the # affine matrix that will bring them back to RAS+ mm space, + where the coordinates refer to the center of the voxel. + + Parameters + ---------- + header : dict + Dict containing trackvis header. + + Returns + ------- + aff_tv2ras : shape (4, 4) array + Affine array mapping coordinates in 'voxelmm' space to RAS+ mm space. + """ # TRK's streamlines are in 'voxelmm' space, we will compute the # affine matrix that will bring them back to RAS+ and mm space. affine = np.eye(4) From 16acde396a55344af316a0b78ae03e469c793908 Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Mon, 6 Jun 2016 18:00:00 -0700 Subject: [PATCH 129/135] DOC: small edits to tractogram_file docstrings --- nibabel/streamlines/tractogram_file.py | 6 ++++-- nibabel/streamlines/trk.py | 4 ++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index b2c7ef0018..1b2a2049da 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -1,3 +1,5 @@ +""" Define abstract interface for Tractogram file classes +""" from abc import ABCMeta, abstractmethod from nibabel.externals.six import with_metaclass @@ -85,7 +87,7 @@ def is_correct_format(cls, fileobj): @abstractclassmethod def load(cls, fileobj, lazy_load=True): - """ Loads streamlines from a file-like object. + """ Loads streamlines from a filename or file-like object. 
Parameters ---------- @@ -107,7 +109,7 @@ def load(cls, fileobj, lazy_load=True): @abstractmethod def save(self, fileobj): - """ Saves streamlines to a file-like object. + """ Saves streamlines to a filename or file-like object. Parameters ---------- diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index dc21316293..4f7ea7a8eb 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -309,7 +309,7 @@ def is_correct_format(cls, fileobj): @classmethod def load(cls, fileobj, lazy_load=False): - """ Loads streamlines from a file-like object. + """ Loads streamlines from a filename or file-like object. Parameters ---------- @@ -410,7 +410,7 @@ def save(self, fileobj): ---------- fileobj : string or file-like object If string, a filename; otherwise an open file-like object - pointing to TRK file (and ready to read from the beginning + pointing to TRK file (and ready to write from the beginning of the TRK header data). """ header = create_empty_header() From d910a5eb18df5ce17e3e26e3c636ac206836ccfb Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Mon, 6 Jun 2016 18:05:12 -0700 Subject: [PATCH 130/135] RF: remove duplicate methods for attribute access We discussed this at https://github.com/nipy/nibabel/pull/391/files#r57834220 --- nibabel/streamlines/tests/test_trk.py | 5 +---- nibabel/streamlines/tractogram_file.py | 14 +------------- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/nibabel/streamlines/tests/test_trk.py b/nibabel/streamlines/tests/test_trk.py index 09cd93d404..f890021689 100644 --- a/nibabel/streamlines/tests/test_trk.py +++ b/nibabel/streamlines/tests/test_trk.py @@ -170,10 +170,7 @@ def test_load_complex_file_in_big_endian(self): def test_tractogram_file_properties(self): trk = TrkFile.load(DATA['simple_trk_fname']) assert_equal(trk.streamlines, trk.tractogram.streamlines) - assert_equal(trk.get_streamlines(), trk.streamlines) - assert_equal(trk.get_tractogram(), trk.tractogram) - assert_equal(trk.get_header(), trk.header) - assert_array_equal(trk.get_affine(), trk.header[Field.VOXEL_TO_RASMM]) + assert_array_equal(trk.affine, trk.header[Field.VOXEL_TO_RASMM]) def test_write_empty_file(self): tractogram = Tractogram(affine_to_rasmm=np.eye(4)) diff --git a/nibabel/streamlines/tractogram_file.py b/nibabel/streamlines/tractogram_file.py index 1b2a2049da..a1dc4e83fb 100644 --- a/nibabel/streamlines/tractogram_file.py +++ b/nibabel/streamlines/tractogram_file.py @@ -51,21 +51,9 @@ def header(self): @property def affine(self): + """ voxmm -> rasmm affine. """ return self.header.get(Field.VOXEL_TO_RASMM) - def get_tractogram(self): - return self.tractogram - - def get_streamlines(self): - return self.streamlines - - def get_header(self): - return self.header - - def get_affine(self): - """ Returns vox -> rasmm affine. """ - return self.affine - @abstractclassmethod def is_correct_format(cls, fileobj): """ Checks if the file has the right streamlines file format. 
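With the duplicate accessors gone, a hedged sketch of the remaining attribute-style interface ('simple.trk' below is a stand-in for any TRK file on disk):

import nibabel as nib

trk = nib.streamlines.load('simple.trk')  # stand-in filename

# Plain properties replace the removed get_* methods:
streamlines = trk.streamlines  # was trk.get_streamlines()
tractogram = trk.tractogram    # was trk.get_tractogram()
header = trk.header            # was trk.get_header()
affine = trk.affine            # was trk.get_affine(); voxmm -> rasmm
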
From d23100418ba1080b8c2b7b129d7c397187304183 Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Mon, 6 Jun 2016 19:03:38 -0700 Subject: [PATCH 131/135] DOC: edits to docstrings --- nibabel/streamlines/tractogram.py | 77 +++++++++++++++---------------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py index 754ef69aad..5a8d65b3ba 100644 --- a/nibabel/streamlines/tractogram.py +++ b/nibabel/streamlines/tractogram.py @@ -80,7 +80,7 @@ class PerArrayDict(SliceableDataDict): In addition, it makes sure the amount of data contained in those ndarrays matches the number of streamlines given at the instantiation of this - dictionary. + instance. """ def __init__(self, nb_elements, *args, **kwargs): self.nb_elements = nb_elements @@ -114,7 +114,7 @@ class PerArraySequenceDict(SliceableDataDict): In addition, it makes sure the amount of data contained in those array sequences matches the number of elements given at the instantiation - of this dictionary. + of the instance. """ def __init__(self, nb_elements, *args, **kwargs): self.nb_elements = nb_elements @@ -136,9 +136,9 @@ def __setitem__(self, key, value): class LazyDict(collections.MutableMapping): """ Dictionary of generator functions. - This container behaves like an dictionary but it makes sure its elements - are callable objects and assumed to be generator function yielding values. - When getting the element associated to a given key, the element (i.e. a + This container behaves like a dictionary but it makes sure its elements are + callable objects that it assumes are generator functions yielding values. + When getting the element associated with a given key, the element (i.e. a generator function) is first called before being returned. """ def __init__(self, *args, **kwargs): @@ -178,7 +178,7 @@ def __len__(self): class TractogramItem(object): """ Class containing information about one streamline. - :class:`TractogramItem` objects have three main properties: `streamline`, + :class:`TractogramItem` objects have three public attributes: `streamline`, `data_for_streamline`, and `data_for_points`. Parameters @@ -187,14 +187,14 @@ class TractogramItem(object): Points of this streamline represented as an ndarray of shape (N, 3) where N is the number of points. data_for_streamline : dict - Dictionary containing some data associated to this particular - streamline. Each key `k` is mapped to a ndarray of shape (Pt,), where - `Pt` is the dimension of the data associated with key `k`. + Dictionary containing some data associated with this particular + streamline. Each key ``k`` is mapped to a ndarray of shape (Pt,), where + ``Pt`` is the dimension of the data associated with key ``k``. data_for_points : dict Dictionary containing some data associated to each point of this - particular streamline. Each key `k` is mapped to a ndarray of - shape (Nt, Mk), where `Nt` is the number of points of this streamline - and `Mk` is the dimension of the data associated with key `k`. + particular streamline. Each key ``k`` is mapped to a ndarray of shape + (Nt, Mk), where ``Nt`` is the number of points of this streamline and + ``Mk`` is the dimension of the data associated with key ``k``. 
""" def __init__(self, streamline, data_for_streamline, data_for_points): self.streamline = np.asarray(streamline) @@ -215,7 +215,7 @@ class Tractogram(object): choice as long as you provide the correct `affine_to_rasmm` matrix, at construction time, that brings the streamlines back to *RAS+*, *mm* space, where the coordinates (0,0,0) corresponds to the center of the voxel - (opposed to a corner). + (as opposed to the corner of the voxel). Attributes ---------- @@ -224,18 +224,18 @@ class Tractogram(object): shape ($N_t$, 3) where $N_t$ is the number of points of streamline $t$. data_per_streamline : :class:`PerArrayDict` object - Dictionary where the items are (str, 2D array). - Each key represents an information $i$ to be kept alongside every - streamline, and its associated value is a 2D array of shape - ($T$, $P_i$) where $T$ is the number of streamlines and $P_i$ is - the number of values to store for that particular information $i$. + Dictionary where the items are (str, 2D array). Each key represents a + piece of information $i$ to be kept alongside every streamline, and its + associated value is a 2D array of shape ($T$, $P_i$) where $T$ is the + number of streamlines and $P_i$ is the number of values to store for + that particular piece of information $i$. data_per_point : :class:`PerArraySequenceDict` object - Dictionary where the items are (str, :class:`ArraySequence`). - Each key represents an information $i$ to be kept alongside every - point of every streamline, and its associated value is an iterable - of ndarrays of shape ($N_t$, $M_i$) where $N_t$ is the number of - points for a particular streamline $t$ and $M_i$ is the number - values to store for that particular information $i$. + Dictionary where the items are (str, :class:`ArraySequence`). Each key + represents a piece of information $i$ to be kept alongside every point + of every streamline, and its associated value is an iterable of + ndarrays of shape ($N_t$, $M_i$) where $N_t$ is the number of points + for a particular streamline $t$ and $M_i$ is the number values to store + for that particular piece of information $i$. """ def __init__(self, streamlines=None, data_per_streamline=None, @@ -424,7 +424,7 @@ class LazyTractogram(Tractogram): choice as long as you provide the correct `affine_to_rasmm` matrix, at construction time, that brings the streamlines back to *RAS+*, *mm* space, where the coordinates (0,0,0) corresponds to the center of the voxel - (opposed to a corner). + (as opposed to the corner of the voxel). Attributes ---------- @@ -432,21 +432,21 @@ class LazyTractogram(Tractogram): Generator function yielding streamlines. Each streamline is an ndarray of shape ($N_t$, 3) where $N_t$ is the number of points of streamline $t$. - data_per_streamline : :class:`LazyDict` object + data_per_streamline : instance of :class:`LazyDict` Dictionary where the items are (str, instantiated generator). - Each key represents an information $i$ to be kept alongside every - streamline, and its associated value is a generator function - yielding that information via ndarrays of shape ($P_i$,) where - $P_i$ is the number of values to store for that particular - information $i$. + Each key represents a piece of information $i$ to be kept alongside + every streamline, and its associated value is a generator function + yielding that information via ndarrays of shape ($P_i$,) where $P_i$ is + the number of values to store for that particular piece of information + $i$. 
    data_per_point : :class:`LazyDict` object
-        Dictionary where the items are (str, instantiated generator).
-        Each key represents an information $i$ to be kept alongside every
-        point of every streamline, and its associated value is a generator
-        function yielding that information via ndarrays of shape
-        ($N_t$, $M_i$) where $N_t$ is the number of points for a particular
-        streamline $t$ and $M_i$ is the number of values to store for
-        that particular information $i$.
+        Dictionary where the items are (str, instantiated generator). Each key
+        represents a piece of information $i$ to be kept alongside every point
+        of every streamline, and its associated value is a generator function
+        yielding that information via ndarrays of shape ($N_t$, $M_i$) where
+        $N_t$ is the number of points for a particular streamline $t$ and $M_i$
+        is the number of values to store for that particular piece of
+        information $i$.
 
     Notes
     -----
@@ -599,7 +599,6 @@ def _apply_affine():
     def _set_streamlines(self, value):
         if value is not None and not callable(value):
             raise TypeError("`streamlines` must be a generator function.")
-
         self._streamlines = value
 
     @property

From 8952c23c053e864de0f05ef5c2ab47b5d5c57074 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?=
Date: Tue, 7 Jun 2016 09:08:28 -0400
Subject: [PATCH 132/135] Renamed LazyTractogram.create_from to
 LazyTractogram.from_data_func

---
 nibabel/streamlines/tests/test_tractogram.py | 8 ++++----
 nibabel/streamlines/tractogram.py            | 2 +-
 nibabel/streamlines/trk.py                   | 2 +-
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py
index d5c1e684e8..f12a6bf0dc 100644
--- a/nibabel/streamlines/tests/test_tractogram.py
+++ b/nibabel/streamlines/tests/test_tractogram.py
@@ -600,11 +600,11 @@ def test_lazy_tractogram_creation(self):
         for i in range(2):
             assert_tractogram_equal(tractogram, DATA['tractogram'])
 
-    def test_lazy_tractogram_create_from(self):
+    def test_lazy_tractogram_from_data_func(self):
         # Create an empty `LazyTractogram` yielding nothing.
         _empty_data_gen = lambda: iter([])
 
-        tractogram = LazyTractogram.create_from(_empty_data_gen)
+        tractogram = LazyTractogram.from_data_func(_empty_data_gen)
         check_tractogram(tractogram)
 
         # Create `LazyTractogram` from a generator function yielding TractogramItem.
         def _data_gen():
             for d in zip(DATA['streamlines'], DATA['fa'],
                          DATA['colors'], DATA['mean_curvature'],
                          DATA['mean_torsion'], DATA['mean_colors']):
                 data_for_streamline = {'mean_curvature': d[3],
                                        'mean_torsion': d[4],
                                        'mean_colors': d[5]}
                 data_for_points = {'fa': d[1],
                                    'colors': d[2]}
                 yield TractogramItem(d[0],
                                      data_for_streamline,
                                      data_for_points)
 
-        tractogram = LazyTractogram.create_from(_data_gen)
+        tractogram = LazyTractogram.from_data_func(_data_gen)
         assert_tractogram_equal(tractogram, DATA['tractogram'])
 
         # Creating a LazyTractogram from something that is not a coroutine should raise an error.
-        assert_raises(TypeError, LazyTractogram.create_from, _data_gen())
+        assert_raises(TypeError, LazyTractogram.from_data_func, _data_gen())
 
     def test_lazy_tractogram_getitem(self):
         assert_raises(NotImplementedError,
diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
index 5a8d65b3ba..cc0718eb2b 100644
--- a/nibabel/streamlines/tractogram.py
+++ b/nibabel/streamlines/tractogram.py
@@ -530,7 +530,7 @@ def _gen(key):
         return lazy_tractogram
 
     @classmethod
-    def create_from(cls, data_func):
+    def from_data_func(cls, data_func):
         """ Creates an instance from a generator function.
 
         The generator function must yield :class:`TractogramItem` objects.
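For illustration, here is a minimal sketch of the renamed classmethod in use, assuming the ``nibabel.streamlines.tractogram`` layout shown in the hunks above; the helper name ``make_items`` and its data are made up for the example, and this is not part of the patches themselves:

    import numpy as np
    from nibabel.streamlines.tractogram import LazyTractogram, TractogramItem

    def make_items():
        # A generator *function*: each call returns a fresh generator
        # yielding one TractogramItem per streamline.
        for nb_points in (3, 5):
            streamline = np.arange(nb_points * 3, dtype=float).reshape(-1, 3)
            yield TractogramItem(streamline,
                                 data_for_streamline={},
                                 data_for_points={})

    # Pass the generator function itself, not a called generator.
    lazy = LazyTractogram.from_data_func(make_items)

    # LazyTractogram.from_data_func(make_items())  # would raise TypeError
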
diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py
index 4f7ea7a8eb..281c25854d 100644
--- a/nibabel/streamlines/trk.py
+++ b/nibabel/streamlines/trk.py
@@ -383,7 +383,7 @@ def _read():
                                     data_for_streamline,
                                     data_for_points)
 
-            tractogram = LazyTractogram.create_from(_read)
+            tractogram = LazyTractogram.from_data_func(_read)
 
         else:
             trk_reader = cls._read(fileobj, hdr)

From f34b779fa94ec879f422c370af6361df9874e9cb Mon Sep 17 00:00:00 2001
From: Matthew Brett
Date: Mon, 6 Jun 2016 20:07:02 -0700
Subject: [PATCH 133/135] RF: simplify constructor for sliceable dicts

Try dropping the special cases for the constructors for sliceable dicts,
and deal with None as input value in the tractogram properties.

---
 nibabel/streamlines/tractogram.py | 43 +++++++++++++++++--------------
 1 file changed, 24 insertions(+), 19 deletions(-)

diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
index cc0718eb2b..222e94acd4 100644
--- a/nibabel/streamlines/tractogram.py
+++ b/nibabel/streamlines/tractogram.py
@@ -25,18 +25,16 @@ class SliceableDataDict(collections.MutableMapping):
 
     This container behaves like a standard dictionary but extends key access
     to allow keys to be indices slicing into the contained ndarray values.
+
+    Parameters
+    ----------
+    \*args :
+    \*\*kwargs :
+        Positional and keyword arguments, passed straight through the ``dict``
+        constructor.
     """
     def __init__(self, *args, **kwargs):
         self.store = dict()
-        # Use the 'update' method to set the keys.
-        if len(args) == 1:
-            if args[0] is None:
-                return
-
-            if isinstance(args[0], SliceableDataDict):
-                self.update(**args[0])
-                return
-
         self.update(dict(*args, **kwargs))
 
     def __getitem__(self, key):
@@ -47,9 +45,9 @@ def __getitem__(self, key):
 
         # Try to interpret key as an index/slice for every data element, in
         # which case we perform (maybe advanced) indexing on every element of
-        # the dictionnary.
+        # the dictionary.
         idx = key
-        new_dict = type(self)(None)
+        new_dict = type(self)()
         try:
             for k, v in self.items():
                 new_dict[k] = v[idx]
@@ -81,8 +79,18 @@ class PerArrayDict(SliceableDataDict):
 
     In addition, it makes sure the amount of data contained in those ndarrays
     matches the number of streamlines given at the instantiation of this
     instance.
+
+    Parameters
+    ----------
+    nb_elements : None or int, optional
+        Number of elements per value in each key, value pair or None for not
+        specified.
+    \*args :
+    \*\*kwargs :
+        Positional and keyword arguments, passed straight through the ``dict``
+        constructor.
     """
-    def __init__(self, nb_elements, *args, **kwargs):
+    def __init__(self, nb_elements=None, *args, **kwargs):
         self.nb_elements = nb_elements
         super(PerArrayDict, self).__init__(*args, **kwargs)
 
@@ -105,7 +113,7 @@ def __setitem__(self, key, value):
         self.store[key] = value
 
 
-class PerArraySequenceDict(SliceableDataDict):
+class PerArraySequenceDict(PerArrayDict):
     """ Dictionary for which key access can do slicing on the values.
 
     This container behaves like a standard dictionary but extends key access
     to allow keys to be indices slicing into the contained array sequence
     values.
 
     In addition, it makes sure the amount of data contained in those array
     sequences matches the number of elements given at the instantiation
     of the instance.
""" - def __init__(self, nb_elements, *args, **kwargs): - self.nb_elements = nb_elements - super(PerArraySequenceDict, self).__init__(*args, **kwargs) - def __setitem__(self, key, value): value = ArraySequence(value) @@ -285,7 +289,8 @@ def data_per_streamline(self): @data_per_streamline.setter def data_per_streamline(self, value): - self._data_per_streamline = PerArrayDict(len(self.streamlines), value) + self._data_per_streamline = PerArrayDict( + len(self.streamlines), {} if value is None else value) @property def data_per_point(self): @@ -294,7 +299,7 @@ def data_per_point(self): @data_per_point.setter def data_per_point(self, value): self._data_per_point = PerArraySequenceDict( - self.streamlines.nb_elements, value) + self.streamlines.nb_elements, {} if value is None else value) @property def affine_to_rasmm(self): From 2dd2d4a7873eb11ce8799db497aff794229f1d4c Mon Sep 17 00:00:00 2001 From: Matthew Brett Date: Tue, 7 Jun 2016 10:57:17 -0700 Subject: [PATCH 134/135] DOC: fix hash typo in docstring --- nibabel/streamlines/trk.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nibabel/streamlines/trk.py b/nibabel/streamlines/trk.py index 281c25854d..9eeef0f1cd 100644 --- a/nibabel/streamlines/trk.py +++ b/nibabel/streamlines/trk.py @@ -93,8 +93,8 @@ def get_affine_trackvis_to_rasmm(header): The streamlines in a trackvis file are in 'voxelmm' space, where the coordinates refer to the corner of the voxel. - Compute the # affine matrix that will bring them back to RAS+ mm space, - where the coordinates refer to the center of the voxel. + Compute the affine matrix that will bring them back to RAS+ mm space, where + the coordinates refer to the center of the voxel. Parameters ---------- From 0a22b931de4c3e001a50ac7010ac11f46592f74a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Tue, 7 Jun 2016 14:53:28 -0400 Subject: [PATCH 135/135] Renamed nb_elements to total_nb_rows --- nibabel/streamlines/array_sequence.py | 6 +++--- .../streamlines/tests/test_array_sequence.py | 4 ++-- nibabel/streamlines/tests/test_tractogram.py | 12 +++++------ nibabel/streamlines/tractogram.py | 20 +++++++++---------- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index ce36f83d47..b41ceb0f90 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -102,8 +102,8 @@ def common_shape(self): return self._data.shape[1:] @property - def nb_elements(self): - """ Total number of elements in this array sequence. """ + def total_nb_rows(self): + """ Total number of rows in this array sequence. 
""" return np.sum(self._lengths) @property @@ -112,7 +112,7 @@ def data(self): return self._data def _get_next_offset(self): - """ Offset in ``self._data`` at which to write next element """ + """ Offset in ``self._data`` at which to write next rowelement """ if len(self._offsets) == 0: return 0 imax = np.argmax(self._offsets) diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index 4f52c5d987..a2ebd3a22e 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -116,8 +116,8 @@ def test_arraysequence_iter(self): def test_arraysequence_copy(self): orig = SEQ_DATA['seq'] seq = orig.copy() - n_rows = seq.nb_elements - assert_equal(n_rows, orig.nb_elements) + n_rows = seq.total_nb_rows + assert_equal(n_rows, orig.total_nb_rows) assert_array_equal(seq._data, orig._data[:n_rows]) assert_true(seq._data is not orig._data) assert_array_equal(seq._offsets, orig._offsets) diff --git a/nibabel/streamlines/tests/test_tractogram.py b/nibabel/streamlines/tests/test_tractogram.py index f12a6bf0dc..76f06dff0e 100644 --- a/nibabel/streamlines/tests/test_tractogram.py +++ b/nibabel/streamlines/tests/test_tractogram.py @@ -186,9 +186,9 @@ class TestPerArraySequenceDict(unittest.TestCase): def test_per_array_sequence_dict_creation(self): # Create a PerArraySequenceDict object using another # PerArraySequenceDict object. - nb_elements = DATA['tractogram'].streamlines.nb_elements + total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows data_per_point = DATA['tractogram'].data_per_point - data_dict = PerArraySequenceDict(nb_elements, data_per_point) + data_dict = PerArraySequenceDict(total_nb_rows, data_per_point) assert_equal(data_dict.keys(), data_per_point.keys()) for k in data_dict.keys(): assert_arrays_equal(data_dict[k], data_per_point[k]) @@ -199,7 +199,7 @@ def test_per_array_sequence_dict_creation(self): # Create a PerArraySequenceDict object using an existing dict object. data_per_point = DATA['data_per_point'] - data_dict = PerArraySequenceDict(nb_elements, data_per_point) + data_dict = PerArraySequenceDict(total_nb_rows, data_per_point) assert_equal(data_dict.keys(), data_per_point.keys()) for k in data_dict.keys(): assert_arrays_equal(data_dict[k], data_per_point[k]) @@ -209,7 +209,7 @@ def test_per_array_sequence_dict_creation(self): # Create a PerArraySequenceDict object using keyword arguments. 
        data_per_point = DATA['data_per_point']
-        data_dict = PerArraySequenceDict(nb_elements, **data_per_point)
+        data_dict = PerArraySequenceDict(total_nb_rows, **data_per_point)
         assert_equal(data_dict.keys(), data_per_point.keys())
         for k in data_dict.keys():
             assert_arrays_equal(data_dict[k], data_per_point[k])
 
         assert_equal(len(data_dict), len(data_per_point)-1)
 
     def test_getitem(self):
-        nb_elements = DATA['tractogram'].streamlines.nb_elements
-        sdict = PerArraySequenceDict(nb_elements, DATA['data_per_point'])
+        total_nb_rows = DATA['tractogram'].streamlines.total_nb_rows
+        sdict = PerArraySequenceDict(total_nb_rows, DATA['data_per_point'])
 
         assert_raises(KeyError, sdict.__getitem__, 'invalid')
 
diff --git a/nibabel/streamlines/tractogram.py b/nibabel/streamlines/tractogram.py
index 222e94acd4..c33f707d1c 100644
--- a/nibabel/streamlines/tractogram.py
+++ b/nibabel/streamlines/tractogram.py
@@ -82,16 +82,16 @@ class PerArrayDict(SliceableDataDict):
 
     Parameters
     ----------
-    nb_elements : None or int, optional
-        Number of elements per value in each key, value pair or None for not
+    n_rows : None or int, optional
+        Number of rows per value in each key, value pair or None for not
         specified.
     \*args :
     \*\*kwargs :
         Positional and keyword arguments, passed straight through the ``dict``
         constructor.
     """
-    def __init__(self, nb_elements=None, *args, **kwargs):
-        self.nb_elements = nb_elements
+    def __init__(self, n_rows=None, *args, **kwargs):
+        self.n_rows = n_rows
         super(PerArrayDict, self).__init__(*args, **kwargs)
 
     def __setitem__(self, key, value):
@@ -105,9 +105,9 @@ def __setitem__(self, key, value):
             raise ValueError("data_per_streamline must be a 2D array.")
 
         # We make sure there is the right amount of values
-        if self.nb_elements is not None and len(value) != self.nb_elements:
+        if self.n_rows is not None and len(value) != self.n_rows:
             msg = ("The number of values ({0}) should match n_rows "
-                   "({1}).").format(len(value), self.nb_elements)
+                   "({1}).").format(len(value), self.n_rows)
             raise ValueError(msg)
 
         self.store[key] = value
@@ -128,10 +128,10 @@ def __setitem__(self, key, value):
         value = ArraySequence(value)
 
         # We make sure there is the right amount of data.
-        if (self.nb_elements is not None and
-                value.nb_elements != self.nb_elements):
+        if (self.n_rows is not None and
+                value.total_nb_rows != self.n_rows):
             msg = ("The number of values ({0}) should match "
-                   "({1}).").format(value.nb_elements, self.nb_elements)
+                   "({1}).").format(value.total_nb_rows, self.n_rows)
             raise ValueError(msg)
 
         self.store[key] = value
@@ -299,7 +299,7 @@ def data_per_point(self):
     @data_per_point.setter
     def data_per_point(self, value):
         self._data_per_point = PerArraySequenceDict(
-            self.streamlines.nb_elements, {} if value is None else value)
+            self.streamlines.total_nb_rows, {} if value is None else value)
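To close, a short sketch of the renamed attributes in action — ``total_nb_rows`` on an ``ArraySequence`` and the matching row-count check in ``PerArraySequenceDict``. Again this is only an illustration assuming the behavior described by the hunks above; the key name ``'fa'`` and the data are made up:

    import numpy as np
    from nibabel.streamlines.array_sequence import ArraySequence
    from nibabel.streamlines.tractogram import PerArraySequenceDict

    # Two streamlines with 3 and 2 points: 5 rows of per-point data in total.
    seq = ArraySequence([np.zeros((3, 3)), np.zeros((2, 3))])
    assert seq.total_nb_rows == 5

    # Per-point data must supply the same total number of rows.
    data = PerArraySequenceDict(seq.total_nb_rows)
    data['fa'] = [np.ones((3, 1)), np.ones((2, 1))]  # accepted: 3 + 2 == 5

    try:
        data['fa'] = [np.ones((3, 1))]  # only 3 rows where 5 are expected
    except ValueError as err:
        print(err)  # "The number of values (3) should match (5)."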