Skip to content

Commit 7044ee4

Browse files
Merge pull request #494 from MarcCote/enh_concatenate_array_sequences
MRG: function to concatenate multiple ArraySequence objects Adds the function `nibabel.streamlines.array_sequence.concatenate(seqs, axis)` that can concatenate multiple `ArraySequence` objects along the provided axis.
2 parents 5d94c73 + 09f5de5 commit 7044ee4

File tree

2 files changed

+44
-1
lines changed

2 files changed

+44
-1
lines changed

nibabel/streamlines/array_sequence.py

+28
Original file line numberDiff line numberDiff line change
@@ -380,3 +380,31 @@ def create_arraysequences_from_generator(gen, n):
380380
for seq in seqs:
381381
seq.finalize_append()
382382
return seqs
383+
384+
385+
def concatenate(seqs, axis):
386+
""" Concatenates multiple :class:`ArraySequence` objects along an axis.
387+
388+
Parameters
389+
----------
390+
seqs: iterable of :class:`ArraySequence` objects
391+
Sequences to concatenate.
392+
axis : int
393+
Axis along which the sequences will be concatenated.
394+
395+
Returns
396+
-------
397+
new_seq: :class:`ArraySequence` object
398+
New :class:`ArraySequence` object which is the result of
399+
concatenating multiple sequences along the given axis.
400+
"""
401+
new_seq = seqs[0].copy()
402+
if axis == 0:
403+
# This is the same as an extend.
404+
for seq in seqs[1:]:
405+
new_seq.extend(seq)
406+
407+
return new_seq
408+
409+
new_seq._data = np.concatenate([seq._data for seq in seqs], axis=axis)
410+
return new_seq

nibabel/streamlines/tests/test_array_sequence.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from nibabel.testing import assert_arrays_equal
99
from numpy.testing import assert_array_equal
1010

11-
from ..array_sequence import ArraySequence, is_array_sequence
11+
from ..array_sequence import ArraySequence, is_array_sequence, concatenate
1212

1313

1414
SEQ_DATA = {}
@@ -299,3 +299,18 @@ def test_save_and_load_arraysequence(self):
299299

300300
# Make sure we can add new elements to it.
301301
loaded_seq.append(SEQ_DATA['data'][0])
302+
303+
304+
def test_concatenate():
305+
seq = SEQ_DATA['seq'].copy() # In case there is in-place modification.
306+
seqs = [seq[:, [i]] for i in range(seq.common_shape[0])]
307+
new_seq = concatenate(seqs, axis=1)
308+
seq._data += 100 # Modifying the 'seq' shouldn't change 'new_seq'.
309+
check_arr_seq(new_seq, SEQ_DATA['data'])
310+
assert_true(not new_seq._is_view)
311+
312+
seq = SEQ_DATA['seq']
313+
seqs = [seq[:, [i]] for i in range(seq.common_shape[0])]
314+
new_seq = concatenate(seqs, axis=0)
315+
assert_true(len(new_seq), seq.common_shape[0] * len(seq))
316+
assert_array_equal(new_seq._data, seq._data.T.reshape((-1, 1)))

0 commit comments

Comments
 (0)