Skip to content

NF: allow truncated last track in trackvis file #346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 19, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions Changelog
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,17 @@ and Stephan Gerhard (SG).

References like "pr/298" refer to github pull request numbers.

* Upcoming

* Trackvis reader will now allow final streamline to have fewer points than
the number declared in the header, with ``strict=False`` argument to
``read`` function;
* Minor API breakage in trackvis reader. We are now raising a DataError if
there are too few streamlines in the file, instead of a HeaderError. We
are raising a DataError if the track is truncated when ``strict=True``
(the default), rather than a TypeError when trying to create the points
array.

* 2.0.1 (Saturday 27 June 2015)

Contributions from Ben Cipollini, Chris Markiewicz, Alexandre Gramfort,
Expand Down
36 changes: 36 additions & 0 deletions nibabel/tests/test_trackvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,3 +591,39 @@ def f(pts): # from vx to mm
out_f.seek(0)
tvf2 = tv.TrackvisFile.from_file(out_f, points_space='rasmm')
assert_true(streamlist_equal(fancy_rasmm_streams, tvf2.streamlines))


def test_read_truncated():
    # Reader behavior when the final track holds fewer points than declared
    base = np.arange(5).reshape(5, 1)
    first_pts = np.tile(base, (1, 3))
    second_pts = np.tile(base + 10, (1, 3))
    full_f = BytesIO()
    tv.write(full_f,
             [(first_pts, None, None), (second_pts, None, None)],
             {})
    # Chop one point (3 values of 4 bytes each) off the end of the last stream
    n_cut = 3 * 4
    value = full_f.getvalue()[:-n_cut]
    short_f = BytesIO(value)
    # Default mode and explicit strict mode both raise a DataError
    for read_kwargs in ({}, dict(strict=True)):
        short_f.seek(0)
        assert_raises(tv.DataError, tv.read, short_f, **read_kwargs)
    # Lenient mode accepts the file and returns the shortened final track
    expected = [(first_pts, None, None), (second_pts[:-1], None, None)]
    short_f.seek(0)
    got_streams, hdr = tv.read(short_f, strict=False)
    assert_true(streamlist_equal(got_streams, expected))

    # Work on a copy of the header so we can rewrite the declared track count
    patched_hdr = hdr.copy()
    assert_equal(patched_hdr['n_count'], 2)

    def file_with_count(n_count):
        # Same truncated track data, but with `n_count` written to the header
        patched_hdr['n_count'] = n_count
        body = value[patched_hdr.itemsize:]
        return BytesIO(patched_hdr.tostring() + body)

    # A count of 0 tells the reader to read until the end of the file; lenient
    # mode should still return the truncated track in that case.
    got_streams, _ = tv.read(file_with_count(0), strict=False)
    assert_true(streamlist_equal(got_streams, expected))
    # Declaring one more track than the file holds is always an error, even
    # in lenient mode.
    assert_raises(tv.DataError, tv.read, file_with_count(3), strict=False)
45 changes: 32 additions & 13 deletions nibabel/trackvis.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ class DataError(Exception):
"""


def read(fileobj, as_generator=False, points_space=None):
''' Read trackvis file, return streamlines, header
def read(fileobj, as_generator=False, points_space=None, strict=True):
''' Read trackvis file from `fileobj`, return `streamlines`, `header`

Parameters
----------
Expand All @@ -116,6 +116,9 @@ def read(fileobj, as_generator=False, points_space=None):
voxel size. If 'rasmm' we'll convert the points to RAS mm space (real
space). For 'rasmm' we check if the affine is set and matches the voxel
sizes and voxel order.
strict : {True, False}, optional
If True, raise error on read for badly-formed file. If False, let pass
files with last track having too few points.

Returns
-------
Expand Down Expand Up @@ -192,22 +195,35 @@ def read(fileobj, as_generator=False, points_space=None):
raise HeaderError('Unexpected negative n_count')

def track_gen():
n_streams = 0
# For case where there are no scalars or no properties
scalars = None
ps = None
while True:
n_streams = 0
# stream_count == 0 signals read to end of file
n_streams_required = stream_count if stream_count != 0 else np.inf
end_of_file = False
while not end_of_file and n_streams < n_streams_required:
n_str = fileobj.read(4)
if len(n_str) < 4:
if stream_count:
raise HeaderError(
'Expecting %s points, found only %s' % (
stream_count, n_streams))
break
n_pts = struct.unpack(i_fmt, n_str)[0]
pts_str = fileobj.read(n_pts * pt_size)
# Check if we got as many bytes as we expect for these points
exp_len = n_pts * pt_size
pts_str = fileobj.read(exp_len)
if len(pts_str) != exp_len:
# Short of bytes, should we raise an error or continue?
actual_n_pts = int(len(pts_str) / pt_size)
if actual_n_pts != n_pts:
if strict == True:
raise DataError('Expecting {0} points for stream {1}, '
'found {2}'.format(
n_pts, n_streams, actual_n_pts))
n_pts = actual_n_pts
end_of_file = True
# Cast bytes to points array
pts = np.ndarray(shape=(n_pts, pt_cols), dtype=f4dt,
buffer=pts_str)
# Add properties
if n_p:
ps_str = fileobj.read(ps_size)
ps = np.ndarray(shape=(n_p,), dtype=f4dt, buffer=ps_str)
Expand All @@ -220,11 +236,14 @@ def track_gen():
scalars = pts[:, 3:]
yield (xyz, scalars, ps)
n_streams += 1
# deliberately misses case where stream_count is 0
if n_streams == stream_count:
fileobj.close_if_mine()
raise StopIteration
# Always close file if we opened it
fileobj.close_if_mine()
# Raise error if we didn't get as many streams as claimed
if n_streams_required != np.inf and n_streams < n_streams_required:
raise DataError(
'Expecting {0} streamlines, found only {1}'.format(
stream_count, n_streams))

streamlines = track_gen()
if not as_generator:
streamlines = list(streamlines)
Expand Down