diff --git a/nibabel/streamlines/array_sequence.py b/nibabel/streamlines/array_sequence.py index 94c8aaf004..d892b9b91d 100644 --- a/nibabel/streamlines/array_sequence.py +++ b/nibabel/streamlines/array_sequence.py @@ -380,3 +380,31 @@ def create_arraysequences_from_generator(gen, n): for seq in seqs: seq.finalize_append() return seqs + + +def concatenate(seqs, axis): + """ Concatenates multiple :class:`ArraySequence` objects along an axis. + + Parameters + ---------- + seqs: iterable of :class:`ArraySequence` objects + Sequences to concatenate. + axis : int + Axis along which the sequences will be concatenated. + + Returns + ------- + new_seq: :class:`ArraySequence` object + New :class:`ArraySequence` object which is the result of + concatenating multiple sequences along the given axis. + """ + new_seq = seqs[0].copy() + if axis == 0: + # This is the same as an extend. + for seq in seqs[1:]: + new_seq.extend(seq) + + return new_seq + + new_seq._data = np.concatenate([seq._data for seq in seqs], axis=axis) + return new_seq diff --git a/nibabel/streamlines/tests/test_array_sequence.py b/nibabel/streamlines/tests/test_array_sequence.py index a2ebd3a22e..42bd6ba49a 100644 --- a/nibabel/streamlines/tests/test_array_sequence.py +++ b/nibabel/streamlines/tests/test_array_sequence.py @@ -8,7 +8,7 @@ from nibabel.testing import assert_arrays_equal from numpy.testing import assert_array_equal -from ..array_sequence import ArraySequence, is_array_sequence +from ..array_sequence import ArraySequence, is_array_sequence, concatenate SEQ_DATA = {} @@ -299,3 +299,18 @@ def test_save_and_load_arraysequence(self): # Make sure we can add new elements to it. loaded_seq.append(SEQ_DATA['data'][0]) + + +def test_concatenate(): + seq = SEQ_DATA['seq'].copy() # In case there is in-place modification. + seqs = [seq[:, [i]] for i in range(seq.common_shape[0])] + new_seq = concatenate(seqs, axis=1) + seq._data += 100 # Modifying the 'seq' shouldn't change 'new_seq'. + check_arr_seq(new_seq, SEQ_DATA['data']) + assert_true(not new_seq._is_view) + + seq = SEQ_DATA['seq'] + seqs = [seq[:, [i]] for i in range(seq.common_shape[0])] + new_seq = concatenate(seqs, axis=0) + assert_true(len(new_seq), seq.common_shape[0] * len(seq)) + assert_array_equal(new_seq._data, seq._data.T.reshape((-1, 1)))