Skip to content

Commit b55e6c8

Browse files
committed
A working prototype of the new streamlines API
1 parent d719976 commit b55e6c8

File tree

16 files changed

+1599
-488
lines changed

16 files changed

+1599
-488
lines changed

nibabel/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
apply_orientation, aff2axcodes)
6464
from .imageclasses import class_map, ext_map
6565
from . import trackvis
66+
from .streamlines import Streamlines
6667

6768
# be friendly on systems with ancient numpy -- no tests, but at least
6869
# importable
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
""" Benchmarks for load and save of streamlines
2+
3+
Run benchmarks with::
4+
5+
import nibabel as nib
6+
nib.bench()
7+
8+
If you have doctests enabled by default in nose (with a noserc file or
9+
environment variable), and you have a numpy version <= 1.6.1, this will also run
10+
the doctests, let's hope they pass.
11+
12+
Run this benchmark with:
13+
14+
nosetests -s --match '(?:^|[\\b_\\.//-])[Bb]ench' /path/to/bench_streamlines.py
15+
"""
16+
from __future__ import division, print_function
17+
18+
import os
19+
import numpy as np
20+
21+
from nibabel.externals.six import BytesIO
22+
from nibabel.externals.six.moves import zip
23+
24+
from nibabel.testing import assert_arrays_equal
25+
26+
from numpy.testing import assert_array_equal
27+
from nibabel.streamlines.base_format import Streamlines
28+
from nibabel.streamlines import TrkFile
29+
30+
import nibabel as nib
31+
import nibabel.trackvis as tv
32+
33+
from numpy.testing import measure
34+
35+
36+
def bench_load_trk():
37+
NB_STREAMLINES = 1000
38+
NB_POINTS = 1000
39+
points = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)]
40+
repeat = 20
41+
42+
trk_file = BytesIO()
43+
trk = list(zip(points, [None]*NB_STREAMLINES, [None]*NB_STREAMLINES))
44+
tv.write(trk_file, trk)
45+
46+
mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, lazy_load=False)', repeat)
47+
print("\nNew: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new))
48+
49+
mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file)', repeat)
50+
print("Old: Loaded %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old))
51+
print("Speedup of %2f" % (mtime_old/mtime_new))
52+
53+
# Points and scalars
54+
scalars = [np.random.rand(NB_POINTS, 10).astype('float32') for i in range(NB_STREAMLINES)]
55+
56+
trk_file = BytesIO()
57+
trk = list(zip(points, scalars, [None]*NB_STREAMLINES))
58+
tv.write(trk_file, trk)
59+
60+
mtime_new = measure('trk_file.seek(0, os.SEEK_SET); nib.streamlines.load(trk_file, lazy_load=False)', repeat)
61+
print("New: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_new))
62+
63+
mtime_old = measure('trk_file.seek(0, os.SEEK_SET); tv.read(trk_file)', repeat)
64+
print("Old: Loaded %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_old))
65+
print("Speedup of %2f" % (mtime_old/mtime_new))
66+
67+
68+
def bench_save_trk():
69+
NB_STREAMLINES = 100
70+
NB_POINTS = 1000
71+
points = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)]
72+
repeat = 10
73+
74+
# Only points
75+
streamlines = Streamlines(points)
76+
trk_file_new = BytesIO()
77+
78+
mtime_new = measure('trk_file_new.seek(0, os.SEEK_SET); TrkFile.save(streamlines, trk_file_new)', repeat)
79+
print("\nNew: Saved %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_new))
80+
81+
trk_file_old = BytesIO()
82+
trk = list(zip(points, [None]*NB_STREAMLINES, [None]*NB_STREAMLINES))
83+
mtime_old = measure('trk_file_old.seek(0, os.SEEK_SET); tv.write(trk_file_old, trk)', repeat)
84+
print("Old: Saved %d streamlines in %6.2f" % (NB_STREAMLINES, mtime_old))
85+
print("Speedup of %2f" % (mtime_old/mtime_new))
86+
87+
trk_file_new.seek(0, os.SEEK_SET)
88+
trk_file_old.seek(0, os.SEEK_SET)
89+
streams, hdr = tv.read(trk_file_old)
90+
91+
for pts, A in zip(points, streams):
92+
assert_array_equal(pts, A[0])
93+
94+
trk = nib.streamlines.load(trk_file_new, lazy_load=False)
95+
96+
assert_arrays_equal(points, trk.points)
97+
98+
# Points and scalars
99+
scalars = [np.random.rand(NB_POINTS, 3).astype('float32') for i in range(NB_STREAMLINES)]
100+
streamlines = Streamlines(points, scalars=scalars)
101+
trk_file_new = BytesIO()
102+
103+
mtime_new = measure('trk_file_new.seek(0, os.SEEK_SET); TrkFile.save(streamlines, trk_file_new)', repeat)
104+
print("New: Saved %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_new))
105+
106+
trk_file_old = BytesIO()
107+
trk = list(zip(points, scalars, [None]*NB_STREAMLINES))
108+
mtime_old = measure('trk_file_old.seek(0, os.SEEK_SET); tv.write(trk_file_old, trk)', repeat)
109+
print("Old: Saved %d streamlines with scalars in %6.2f" % (NB_STREAMLINES, mtime_old))
110+
print("Speedup of %2f" % (mtime_old/mtime_new))
111+
112+
trk_file_new.seek(0, os.SEEK_SET)
113+
trk_file_old.seek(0, os.SEEK_SET)
114+
streams, hdr = tv.read(trk_file_old)
115+
116+
for pts, scal, A in zip(points, scalars, streams):
117+
assert_array_equal(pts, A[0])
118+
assert_array_equal(scal, A[1])
119+
120+
trk = nib.streamlines.load(trk_file_new, lazy_load=False)
121+
122+
assert_arrays_equal(points, trk.points)
123+
assert_arrays_equal(scalars, trk.scalars)
124+
125+
126+
if __name__ == '__main__':
127+
bench_save_trk()

nibabel/externals/six.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ class _MovedItems(types.ModuleType):
143143
MovedAttribute("StringIO", "StringIO", "io"),
144144
MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"),
145145
MovedAttribute("zip", "itertools", "builtins", "izip", "zip"),
146+
MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"),
146147

147148
MovedModule("builtins", "__builtin__"),
148149
MovedModule("configparser", "ConfigParser"),

nibabel/streamlines/__init__.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,9 @@
1-
from nibabel.openers import Opener
21

3-
from nibabel.streamlines.utils import detect_format
2+
from nibabel.streamlines.utils import load, save
43

4+
from nibabel.streamlines.base_format import Streamlines
5+
from nibabel.streamlines.header import Field
56

6-
def load(fileobj):
7-
''' Load a file of streamlines, return instance associated to file format
8-
9-
Parameters
10-
----------
11-
fileobj : string or file-like object
12-
If string, a filename; otherwise an open file-like object
13-
pointing to a streamlines file (and ready to read from the beginning
14-
of the streamlines file's header)
15-
16-
Returns
17-
-------
18-
obj : instance of ``StreamlineFile``
19-
Returns an instance of a ``StreamlineFile`` subclass corresponding to
20-
the format of the streamlines file ``fileobj``.
21-
'''
22-
fileobj = Opener(fileobj)
23-
streamlines_file = detect_format(fileobj)
24-
return streamlines_file.load(fileobj)
7+
from nibabel.streamlines.trk import TrkFile
8+
#from nibabel.streamlines.trk import TckFile
9+
#from nibabel.streamlines.trk import VtkFile

nibabel/streamlines/base_format.py

Lines changed: 170 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,193 @@
11

2-
class StreamlineFile:
3-
@staticmethod
4-
def get_magic_number():
5-
raise NotImplementedError()
2+
from nibabel.streamlines.header import Field
63

7-
@staticmethod
8-
def is_correct_format(cls, fileobj):
9-
raise NotImplementedError()
4+
from ..externals.six.moves import zip_longest
5+
6+
7+
class HeaderError(Exception):
8+
pass
9+
10+
11+
class DataError(Exception):
12+
pass
13+
14+
15+
class Streamlines(object):
16+
''' Class containing information about streamlines.
17+
18+
Streamlines objects have three main properties: ``points``, ``scalars``
19+
and ``properties``. Streamlines objects can be iterate over producing
20+
tuple of ``points``, ``scalars`` and ``properties`` for each streamline.
21+
22+
Parameters
23+
----------
24+
points : sequence of ndarray of shape (N, 3)
25+
Sequence of T streamlines. One streamline is an ndarray of shape (N, 3)
26+
where N is the number of points in a streamline.
27+
28+
scalars : sequence of ndarray of shape (N, M)
29+
Sequence of T ndarrays of shape (N, M) where T is the number of
30+
streamlines defined by ``points``, N is the number of points
31+
for a particular streamline and M is the number of scalars
32+
associated to each point (excluding the three coordinates).
33+
34+
properties : sequence of ndarray of shape (P,)
35+
Sequence of T ndarrays of shape (P,) where T is the number of
36+
streamlines defined by ``points``, P is the number of properties
37+
associated to each streamlines.
38+
39+
hdr : dict
40+
Header containing meta information about the streamlines. For a list
41+
of common header's fields to use as keys see `nibabel.streamlines.Field`.
42+
'''
43+
def __init__(self, points=[], scalars=[], properties=[], hdr={}):
44+
self.hdr = hdr
45+
46+
self.points = points
47+
self.scalars = scalars
48+
self.properties = properties
49+
self.data = lambda: zip_longest(self.points, self.scalars, self.properties, fillvalue=[])
50+
51+
try:
52+
self.length = len(points)
53+
except:
54+
if Field.NB_STREAMLINES in hdr:
55+
self.length = hdr[Field.NB_STREAMLINES]
56+
else:
57+
raise HeaderError(("Neither parameter 'points' nor 'hdr' contain information about"
58+
" number of streamlines. Use key '{0}' to set the number of "
59+
"streamlines in 'hdr'.").format(Field.NB_STREAMLINES))
1060

1161
def get_header(self):
12-
raise NotImplementedError()
62+
return self.hdr
63+
64+
@property
65+
def points(self):
66+
return self._points()
67+
68+
@points.setter
69+
def points(self, value):
70+
self._points = value if callable(value) else (lambda: value)
71+
72+
@property
73+
def scalars(self):
74+
return self._scalars()
75+
76+
@scalars.setter
77+
def scalars(self, value):
78+
self._scalars = value if callable(value) else lambda: value
79+
80+
@property
81+
def properties(self):
82+
return self._properties()
1383

14-
def get_streamlines(self, as_generator=False):
84+
@properties.setter
85+
def properties(self, value):
86+
self._properties = value if callable(value) else lambda: value
87+
88+
def __iter__(self):
89+
return self.data()
90+
91+
def __len__(self):
92+
return self.length
93+
94+
95+
class StreamlinesFile:
96+
''' Convenience class to encapsulate streamlines file format. '''
97+
98+
@classmethod
99+
def get_magic_number(cls):
100+
''' Return streamlines file's magic number. '''
15101
raise NotImplementedError()
16102

17-
def get_scalars(self, as_generator=False):
103+
@classmethod
104+
def is_correct_format(cls, fileobj):
105+
''' Check if the file has the right streamlines file format.
106+
107+
Parameters
108+
----------
109+
fileobj : string or file-like object
110+
If string, a filename; otherwise an open file-like object
111+
pointing to a streamlines file (and ready to read from the
112+
beginning of the header)
113+
114+
Returns
115+
-------
116+
is_correct_format : boolean
117+
Returns True if `fileobj` is in the right streamlines file format.
118+
'''
18119
raise NotImplementedError()
19120

20-
def get_properties(self, as_generator=False):
121+
@classmethod
122+
def get_empty_header(cls):
123+
''' Return an empty streamlines file's header. '''
21124
raise NotImplementedError()
22125

23126
@classmethod
24-
def load(cls, fileobj):
127+
def load(cls, fileobj, lazy_load=True):
128+
''' Loads streamlines from a file-like object.
129+
130+
Parameters
131+
----------
132+
fileobj : string or file-like object
133+
If string, a filename; otherwise an open file-like object
134+
pointing to a streamlines file (and ready to read from the
135+
beginning of the header)
136+
137+
lazy_load : boolean
138+
Load streamlines in a lazy manner i.e. they will not be kept
139+
in memory. For postprocessing speed, turn off this option.
140+
141+
Returns
142+
-------
143+
streamlines : Streamlines object
144+
Returns an object containing streamlines' data and header
145+
information. See 'nibabel.Streamlines'.
146+
'''
25147
raise NotImplementedError()
26148

27-
def save(self, filename):
149+
@classmethod
150+
def save(cls, streamlines, fileobj):
151+
''' Saves streamlines to a file-like object.
152+
153+
Parameters
154+
----------
155+
streamlines : Streamlines object
156+
Object containing streamlines' data and header information.
157+
See 'nibabel.Streamlines'.
158+
159+
fileobj : string or file-like object
160+
If string, a filename; otherwise an open file-like object
161+
opened and ready to write.
162+
'''
28163
raise NotImplementedError()
29164

30-
def __iter__(self):
165+
@staticmethod
166+
def pretty_print(streamlines):
167+
''' Gets a formatted string contaning header's information
168+
relevant to the streamlines file format.
169+
170+
Parameters
171+
----------
172+
streamlines : Streamlines object
173+
Object containing streamlines' data and header information.
174+
See 'nibabel.Streamlines'.
175+
176+
Returns
177+
-------
178+
info : string
179+
Header's information relevant to the streamlines file format.
180+
'''
31181
raise NotImplementedError()
32182

33183

34-
class DynamicStreamlineFile(StreamlineFile):
184+
class DynamicStreamlineFile(StreamlinesFile):
185+
''' Convenience class to encapsulate streamlines file format
186+
that supports appending streamlines to an existing file.
187+
'''
188+
35189
def append(self, streamlines):
36190
raise NotImplementedError()
37191

38192
def __iadd__(self, streamlines):
39-
return self.append(streamlines)
193+
return self.append(streamlines)

0 commit comments

Comments
 (0)