Skip to content

Commit f788cb9

Browse files
committed
Add a function to concatenate multiple ArraySequences object given an axis.
1 parent ec4567f commit f788cb9

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

nibabel/streamlines/array_sequence.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,7 @@ def __getitem__(self, idx):
318318
raise TypeError("Index must be either an int, a slice, a list of int"
319319
" or a ndarray of bool! Not " + str(type(idx)))
320320

321+
321322
def __iter__(self):
322323
if len(self._lengths) != len(self._offsets):
323324
raise ValueError("ArraySequence object corrupted:"
@@ -380,3 +381,31 @@ def create_arraysequences_from_generator(gen, n):
380381
for seq in seqs:
381382
seq.finalize_append()
382383
return seqs
384+
385+
386+
def concatenate(seqs, axis):
387+
""" Concatenates multiple :class:`ArraySequence` objects along an axis.
388+
389+
Parameters
390+
----------
391+
seqs: list of :class:`ArraySequence` objects
392+
Sequences to concatenate.
393+
axis : int
394+
Axis along which the sequences will be concatenated.
395+
396+
Returns
397+
-------
398+
new_seq: :class:`ArraySequence` object
399+
New :class:`ArraySequence` object which is the result of
400+
concatenating multiple sequences along the given axis.
401+
"""
402+
new_seq = seqs[0].copy()
403+
if axis == 0:
404+
# This is the same as an extend.
405+
for seq in seqs[1:]:
406+
new_seq.extend(seq)
407+
408+
return new_seq
409+
410+
new_seq._data = np.concatenate([seq._data for seq in seqs], axis=axis)
411+
return new_seq

nibabel/streamlines/tests/test_array_sequence.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from nibabel.testing import assert_arrays_equal
99
from numpy.testing import assert_array_equal
1010

11-
from ..array_sequence import ArraySequence, is_array_sequence
11+
from ..array_sequence import ArraySequence, is_array_sequence, concatenate, create_arraysequences_from_generator
1212

1313

1414
SEQ_DATA = {}
@@ -299,3 +299,18 @@ def test_save_and_load_arraysequence(self):
299299

300300
# Make sure we can add new elements to it.
301301
loaded_seq.append(SEQ_DATA['data'][0])
302+
303+
304+
def test_concatenate():
305+
seq = SEQ_DATA['seq'].copy() # In case there is in-place modification.
306+
seqs = [seq[:, [i]] for i in range(seq.common_shape[0])]
307+
new_seq = concatenate(seqs, axis=1)
308+
check_arr_seq(new_seq, SEQ_DATA['data'])
309+
assert_true(not new_seq._is_view)
310+
311+
seq = SEQ_DATA['seq'].copy() # In case there is in-place modification.
312+
seqs = [seq[:, [i]] for i in range(seq.common_shape[0])]
313+
new_seq = concatenate(seqs, axis=0)
314+
315+
assert_true(len(new_seq), seq.common_shape[0]*len(seq))
316+
assert_array_equal(new_seq._data, seq._data.T.reshape((-1, 1)))

0 commit comments

Comments
 (0)