File tree Expand file tree Collapse file tree 2 files changed +44
-1
lines changed Expand file tree Collapse file tree 2 files changed +44
-1
lines changed Original file line number Diff line number Diff line change @@ -380,3 +380,31 @@ def create_arraysequences_from_generator(gen, n):
380380 for seq in seqs :
381381 seq .finalize_append ()
382382 return seqs
383+
384+
385+ def concatenate (seqs , axis ):
386+ """ Concatenates multiple :class:`ArraySequence` objects along an axis.
387+
388+ Parameters
389+ ----------
390+ seqs: iterable of :class:`ArraySequence` objects
391+ Sequences to concatenate.
392+ axis : int
393+ Axis along which the sequences will be concatenated.
394+
395+ Returns
396+ -------
397+ new_seq: :class:`ArraySequence` object
398+ New :class:`ArraySequence` object which is the result of
399+ concatenating multiple sequences along the given axis.
400+ """
401+ new_seq = seqs [0 ].copy ()
402+ if axis == 0 :
403+ # This is the same as an extend.
404+ for seq in seqs [1 :]:
405+ new_seq .extend (seq )
406+
407+ return new_seq
408+
409+ new_seq ._data = np .concatenate ([seq ._data for seq in seqs ], axis = axis )
410+ return new_seq
Original file line number Diff line number Diff line change 88from nibabel .testing import assert_arrays_equal
99from numpy .testing import assert_array_equal
1010
11- from ..array_sequence import ArraySequence , is_array_sequence
11+ from ..array_sequence import ArraySequence , is_array_sequence , concatenate
1212
1313
1414SEQ_DATA = {}
@@ -299,3 +299,18 @@ def test_save_and_load_arraysequence(self):
299299
300300 # Make sure we can add new elements to it.
301301 loaded_seq .append (SEQ_DATA ['data' ][0 ])
302+
303+
304+ def test_concatenate ():
305+ seq = SEQ_DATA ['seq' ].copy () # In case there is in-place modification.
306+ seqs = [seq [:, [i ]] for i in range (seq .common_shape [0 ])]
307+ new_seq = concatenate (seqs , axis = 1 )
308+ seq ._data += 100 # Modifying the 'seq' shouldn't change 'new_seq'.
309+ check_arr_seq (new_seq , SEQ_DATA ['data' ])
310+ assert_true (not new_seq ._is_view )
311+
312+ seq = SEQ_DATA ['seq' ]
313+ seqs = [seq [:, [i ]] for i in range (seq .common_shape [0 ])]
314+ new_seq = concatenate (seqs , axis = 0 )
315+ assert_true (len (new_seq ), seq .common_shape [0 ] * len (seq ))
316+ assert_array_equal (new_seq ._data , seq ._data .T .reshape ((- 1 , 1 )))
You can’t perform that action at this time.
0 commit comments