diff --git a/Changelog b/Changelog index 8bd1a39839..b41108d210 100644 --- a/Changelog +++ b/Changelog @@ -24,6 +24,17 @@ and Stephan Gerhard (SG). References like "pr/298" refer to github pull request numbers. +* Upcoming + + * Trackvis reader will now allow final streamline to have fewer points than + the number declared in the header, with ``strict=False`` argument to + ``read`` function; + * Minor API breakage in trackvis reader. We are now raising a DataError if + there are too few streamlines in the file, instead of a HeaderError. We + are raising a DataError if the track is truncated when ``strict=True`` + (the default), rather than a TypeError when trying to create the points + array. + * 2.0.1 (Saturday 27 June 2015) Contributions from Ben Cipollini, Chris Markiewicz, Alexandre Gramfort, diff --git a/nibabel/tests/test_trackvis.py b/nibabel/tests/test_trackvis.py index 00ec2f0732..94e5bf93bb 100644 --- a/nibabel/tests/test_trackvis.py +++ b/nibabel/tests/test_trackvis.py @@ -591,3 +591,39 @@ def f(pts): # from vx to mm out_f.seek(0) tvf2 = tv.TrackvisFile.from_file(out_f, points_space='rasmm') assert_true(streamlist_equal(fancy_rasmm_streams, tvf2.streamlines)) + + +def test_read_truncated(): + # Test behavior when last track contains fewer points than specified + out_f = BytesIO() + xyz0 = np.tile(np.arange(5).reshape(5, 1), (1, 3)) + xyz1 = np.tile(np.arange(5).reshape(5, 1) + 10, (1, 3)) + streams = [(xyz0, None, None), (xyz1, None, None)] + tv.write(out_f, streams, {}) + # Truncate the last stream by one point + value = out_f.getvalue()[:-(3 * 4)] + new_f = BytesIO(value) + # By default, raises a DataError + assert_raises(tv.DataError, tv.read, new_f) + # This corresponds to strict mode + new_f.seek(0) + assert_raises(tv.DataError, tv.read, new_f, strict=True) + # lenient error mode lets this error pass, with truncated track + short_streams = [(xyz0, None, None), (xyz1[:-1], None, None)] + new_f.seek(0) + streams2, hdr = tv.read(new_f, strict=False) + 
assert_true(streamlist_equal(streams2, short_streams)) + # Check that lenient works when number of tracks is 0, where 0 signals to + # the reader to read until the end of the file. + again_hdr = hdr.copy() + assert_equal(again_hdr['n_count'], 2) + again_hdr['n_count'] = 0 + again_bytes = again_hdr.tostring() + value[again_hdr.itemsize:] + again_f = BytesIO(again_bytes) + streams2, _ = tv.read(again_f, strict=False) + assert_true(streamlist_equal(streams2, short_streams)) + # Set count to one above actual number of tracks, always raise error + again_hdr['n_count'] = 3 + again_bytes = again_hdr.tostring() + value[again_hdr.itemsize:] + again_f = BytesIO(again_bytes) + assert_raises(tv.DataError, tv.read, again_f, strict=False) diff --git a/nibabel/trackvis.py b/nibabel/trackvis.py index 3af06802b5..a76baaff75 100644 --- a/nibabel/trackvis.py +++ b/nibabel/trackvis.py @@ -95,8 +95,8 @@ class DataError(Exception): """ -def read(fileobj, as_generator=False, points_space=None): - ''' Read trackvis file, return streamlines, header +def read(fileobj, as_generator=False, points_space=None, strict=True): + ''' Read trackvis file from `fileobj`, return `streamlines`, `header` Parameters ---------- @@ -116,6 +116,9 @@ def read(fileobj, as_generator=False, points_space=None): voxel size. If 'rasmm' we'll convert the points to RAS mm space (real space). For 'rasmm' we check if the affine is set and matches the voxel sizes and voxel order. + strict : {True, False}, optional + If True, raise error on read for badly-formed file. If False, let pass + files with last track having too few points. 
Returns ------- @@ -192,22 +195,35 @@ def read(fileobj, as_generator=False, points_space=None): raise HeaderError('Unexpected negative n_count') def track_gen(): - n_streams = 0 # For case where there are no scalars or no properties scalars = None ps = None - while True: + n_streams = 0 + # stream_count == 0 signals read to end of file + n_streams_required = stream_count if stream_count != 0 else np.inf + end_of_file = False + while not end_of_file and n_streams < n_streams_required: n_str = fileobj.read(4) if len(n_str) < 4: - if stream_count: - raise HeaderError( - 'Expecting %s points, found only %s' % ( - stream_count, n_streams)) break n_pts = struct.unpack(i_fmt, n_str)[0] - pts_str = fileobj.read(n_pts * pt_size) + # Check if we got as many bytes as we expect for these points + exp_len = n_pts * pt_size + pts_str = fileobj.read(exp_len) + if len(pts_str) != exp_len: + # Short of bytes, should we raise an error or continue? + actual_n_pts = int(len(pts_str) / pt_size) + if actual_n_pts != n_pts: + if strict == True: + raise DataError('Expecting {0} points for stream {1}, ' + 'found {2}'.format( + n_pts, n_streams, actual_n_pts)) + n_pts = actual_n_pts + end_of_file = True + # Cast bytes to points array pts = np.ndarray(shape=(n_pts, pt_cols), dtype=f4dt, buffer=pts_str) + # Add properties if n_p: ps_str = fileobj.read(ps_size) ps = np.ndarray(shape=(n_p,), dtype=f4dt, buffer=ps_str) @@ -220,11 +236,14 @@ def track_gen(): scalars = pts[:, 3:] yield (xyz, scalars, ps) n_streams += 1 - # deliberately misses case where stream_count is 0 - if n_streams == stream_count: - fileobj.close_if_mine() - raise StopIteration + # Always close file if we opened it fileobj.close_if_mine() + # Raise error if we didn't get as many streams as claimed + if n_streams_required != np.inf and n_streams < n_streams_required: + raise DataError( + 'Expecting {0} streamlines, found only {1}'.format( + stream_count, n_streams)) + streamlines = track_gen() if not as_generator: streamlines 
= list(streamlines)