Skip to content

Commit a848bb4

Browse files
author
Soichi Hayashi
authored
Merge pull request #1 from effigies/fix/streamlines_infnan
RF: Use bytearray/frombuffer and other minor fixes
2 parents 62c0cda + 0a75431 commit a848bb4

File tree

2 files changed

+35
-24
lines changed

2 files changed

+35
-24
lines changed

nibabel/streamlines/tck.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -386,8 +386,10 @@ def _read(cls, fileobj, header, buffer_size=4):
386386
"""
387387
dtype = header["_dtype"]
388388

389-
#align batch_size to be multiple of 3 within the specified buffer size
390-
batch_size = int(buffer_size * MEGABYTE / dtype.itemsize / 3) * 3
389+
coordinate_size = 3 * dtype.itemsize
390+
# Make buffer_size an integer and a multiple of coordinate_size.
391+
buffer_size = int(buffer_size * MEGABYTE)
392+
buffer_size += coordinate_size - (buffer_size % coordinate_size)
391393

392394
with Opener(fileobj) as f:
393395
start_position = f.tell()
@@ -397,38 +399,39 @@ def _read(cls, fileobj, header, buffer_size=4):
397399

398400
eof = False
399401
n_streams = 0
400-
leftover = np.empty((0,3), dtype='<f4')
402+
leftover = np.empty((0, 3), dtype='<f4')
401403
while not eof:
404+
buff = bytearray(buffer_size)
405+
n_read = f.readinto(buff)
406+
eof = n_read != buffer_size
407+
if eof:
408+
buff = buff[:n_read]
402409

403410
# read raw values from file
404-
raw_values = np.fromfile(f.fobj, dtype, batch_size)
405-
if len(raw_values) < batch_size:
406-
eof = True
411+
raw_values = np.frombuffer(buff, dtype=dtype)
407412

408-
# Convert raw_values into a list of little-endian tuples (for x,y,z coord)
409-
coords = raw_values.astype('<f4', copy=False).reshape([-1, 3])
413+
# Convert raw_values into a list of little-endian triples (for x,y,z coord)
414+
coords = raw_values.astype('<f4', copy=False).reshape((-1, 3))
410415

411-
# find stream delimiter locations (all NaNs)
412-
delims = np.where(np.all(np.isnan(coords), axis=1))[0]
416+
# Find stream delimiter locations (all NaNs)
417+
delims = np.where(np.isnan(coords).all(axis=1))[0]
418+
419+
if leftover.size:
420+
delims += leftover.shape[0]
421+
coords = np.vstack((leftover, coords))
413422

414-
# for each delimiters, yeild new streams
415423
begin = 0
416-
for i in range(0, len(delims)):
417-
end = delims[i]
418-
if i == 0:
419-
stream = np.vstack((leftover, coords[begin:end]))
420-
else:
421-
stream = coords[begin:end]
422-
leftover = np.empty((0,3), dtype='<f4')
423-
yield stream
424-
n_streams += 1
425-
426-
begin = end+1 #skip the delimiter
424+
for delim in delims:
425+
pts = coords[begin:delim]
426+
if pts.size:
427+
yield coords[begin:delim]
428+
n_streams += 1
429+
begin = delim + 1
427430

428431
# the rest becomes the new leftover
429-
leftover = np.vstack((leftover, coords[begin:]))
432+
leftover = coords[begin:]
430433

431-
if not np.all(np.isinf(leftover), axis=1):
434+
if not (leftover.shape == (1, 3) and np.isinf(leftover).all()):
432435
raise DataError("Expecting end-of-file marker 'inf inf inf'")
433436

434437
# In case the 'count' field was not provided.

nibabel/streamlines/tests/test_tck.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ def setup():
3232
"simple_big_endian.tck")
3333
# standard.tck contains only streamlines
3434
DATA['standard_tck_fname'] = pjoin(data_path, "standard.tck")
35+
DATA['matlab_nan_tck_fname'] = pjoin(data_path, "matlab_nan.tck")
3536

3637
DATA['streamlines'] = [np.arange(1 * 3, dtype="f4").reshape((1, 3)),
3738
np.arange(2 * 3, dtype="f4").reshape((2, 3)),
@@ -64,6 +65,13 @@ def test_load_simple_file(self):
6465
tck = TckFile(tractogram, header=hdr)
6566
assert_tractogram_equal(tck.tractogram, DATA['simple_tractogram'])
6667

68+
def test_load_matlab_nan_file(self):
69+
for lazy_load in [False, True]:
70+
tck = TckFile.load(DATA['matlab_nan_tck_fname'], lazy_load=lazy_load)
71+
streamlines = list(tck.tractogram.streamlines)
72+
assert_equal(len(streamlines), 1)
73+
assert_equal(streamlines[0].shape, (108, 3))
74+
6775
def test_writeable_data(self):
6876
data = DATA['simple_tractogram']
6977
for key in ('simple_tck_fname', 'simple_tck_big_endian_fname'):

0 commit comments

Comments
 (0)